diff --git a/pyproject.toml b/pyproject.toml index 921d4d1..0dbfbda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,3 +47,12 @@ packages = ["src/geo_assistant"] [tool.pytest.ini_options] addopts = "--color=yes" asyncio_mode = "auto" + +[tool.ruff.lint] +select = [ + "COM", # flake8-commas + "F", # pyflakes + "I", # isort + "RUF", # ruff-specific + "UP", # pyupgrade +] diff --git a/src/geo_assistant/agent/graph.py b/src/geo_assistant/agent/graph.py index 96d7fb4..f88fa2f 100644 --- a/src/geo_assistant/agent/graph.py +++ b/src/geo_assistant/agent/graph.py @@ -1,12 +1,13 @@ import datetime -from langgraph.checkpoint.memory import InMemorySaver from langchain.agents import create_agent -from geo_assistant.agent.state import GeoAssistantState +from langgraph.checkpoint.memory import InMemorySaver + from geo_assistant.agent.llms import llm +from geo_assistant.agent.state import GeoAssistantState +from geo_assistant.tools.buffer import get_search_area from geo_assistant.tools.naip import fetch_naip_img from geo_assistant.tools.overture import get_place -from geo_assistant.tools.buffer import get_search_area from geo_assistant.tools.summarize import summarize_sat_img SYSTEM_PROMPT = """ @@ -36,7 +37,7 @@ async def create_graph(): summarize_sat_img, ], system_prompt=SYSTEM_PROMPT.format( - now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), ), state_schema=GeoAssistantState, checkpointer=checkpointer, diff --git a/src/geo_assistant/agent/llms.py b/src/geo_assistant/agent/llms.py index 69f6150..7fc3cfb 100644 --- a/src/geo_assistant/agent/llms.py +++ b/src/geo_assistant/agent/llms.py @@ -1,6 +1,6 @@ import os -from dotenv import load_dotenv +from dotenv import load_dotenv from langchain_ollama import ChatOllama # Load environment variables from env file diff --git a/src/geo_assistant/agent/state.py b/src/geo_assistant/agent/state.py index 40973ed..0177555 100644 --- a/src/geo_assistant/agent/state.py +++ b/src/geo_assistant/agent/state.py @@ -1,6 +1,7 @@ -from langchain.agents import AgentState +from typing import NotRequired + from geojson_pydantic import Feature -from typing_extensions import NotRequired +from langchain.agents import AgentState from pydantic import Field @@ -8,5 +9,6 @@ class GeoAssistantState(AgentState): place: NotRequired[Feature | None] = None search_area: NotRequired[Feature | None] = None naip_png_path: NotRequired[str | None] = Field( - default=None, description="Path to the saved NAIP RGB PNG image" + default=None, + description="Path to the saved NAIP RGB PNG image", ) diff --git a/src/geo_assistant/api/app.py b/src/geo_assistant/api/app.py index 739d932..c8680ea 100644 --- a/src/geo_assistant/api/app.py +++ b/src/geo_assistant/api/app.py @@ -1,11 +1,12 @@ import json +import logging +from collections.abc import AsyncGenerator from contextlib import aclosing, asynccontextmanager -from typing import Any, AsyncGenerator, Dict +from typing import Any from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse -import logging from pydantic import UUID4 from geo_assistant.agent.graph import create_graph @@ -43,11 +44,11 @@ async def stream_chat( thread_id: UUID4, chatbot: Any, request: Request, -) -> AsyncGenerator[bytes, None]: - config: Dict[str, Any] = { +) -> AsyncGenerator[bytes]: + config: dict[str, Any] = { "configurable": { "thread_id": str(thread_id), - } + }, } state_updates = {} @@ -69,7 +70,7 @@ async def stream_chat( { "content": f"Manually selected data for field {key}: {description}", "type": "human", - } + }, ) # Add UI messages to the existing messages if they exist diff --git a/src/geo_assistant/api/schemas/chat.py b/src/geo_assistant/api/schemas/chat.py index f757ce1..56927b6 100644 --- a/src/geo_assistant/api/schemas/chat.py +++ b/src/geo_assistant/api/schemas/chat.py @@ -1,4 +1,5 @@ from pydantic import BaseModel + from geo_assistant.agent.state import GeoAssistantState diff --git a/src/geo_assistant/frontend/app.py b/src/geo_assistant/frontend/app.py index b35403e..e0fb44d 100644 --- a/src/geo_assistant/frontend/app.py +++ b/src/geo_assistant/frontend/app.py @@ -2,10 +2,10 @@ import json import os import uuid +import folium import httpx import streamlit as st import streamlit.components.v1 as components -import folium from dotenv import load_dotenv # Load environment variables from .env file @@ -137,7 +137,10 @@ def stream_chat(user_message: str): # Fit map to bounds if we have coordinates if all_lons and all_lats: m.fit_bounds( - [[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]] + [ + [min(all_lats), min(all_lons)], + [max(all_lats), max(all_lons)], + ], ) # Display the map diff --git a/src/geo_assistant/tools/buffer.py b/src/geo_assistant/tools/buffer.py index b72d813..8331624 100644 --- a/src/geo_assistant/tools/buffer.py +++ b/src/geo_assistant/tools/buffer.py @@ -1,12 +1,14 @@ +from typing import Annotated + import geopandas as gpd -from langgraph.types import Command -from langgraph.prebuilt import InjectedState -from langchain_core.tools.base import InjectedToolCallId +from geojson_pydantic import Feature from langchain_core.messages import ToolMessage from langchain_core.tools import tool -from typing import Annotated +from langchain_core.tools.base import InjectedToolCallId +from langgraph.prebuilt import InjectedState +from langgraph.types import Command + from geo_assistant.agent.state import GeoAssistantState -from geojson_pydantic import Feature @tool @@ -26,9 +28,9 @@ async def get_search_area( ToolMessage( content="No place defined in the agent state to create a search area around.", tool_call_id=tool_call_id, - ) + ), ], - } + }, ) # Convert GeoJSON feature to GeoDataFrame @@ -38,7 +40,7 @@ async def get_search_area( gdf_m = gdf.to_crs(epsg=3857) # latlon to Web Mercator for meter-based buffering gdf_m["geometry"] = gdf_m["geometry"].buffer( - buffer_size_km * 1000 + buffer_size_km * 1000, ) # Buffer in meters gdf = gdf_m.to_crs(epsg=4326) # Back to WGS84 @@ -46,7 +48,7 @@ async def get_search_area( if len(gdf) != 1: raise ValueError( f"{len(gdf)} features found after buffer operation, should be just 1. " - "Was a Multi-Point/LineString/Polygon geometry passed in?" + "Was a Multi-Point/LineString/Polygon geometry passed in?", ) buffer_feature = Feature( type="Feature", @@ -61,7 +63,7 @@ async def get_search_area( ToolMessage( content=f"Created search area geometry buffer of {buffer_size_km} km around the place.", tool_call_id=tool_call_id, - ) + ), ], - } + }, ) diff --git a/src/geo_assistant/tools/naip.py b/src/geo_assistant/tools/naip.py index 878a65d..c8dc368 100644 --- a/src/geo_assistant/tools/naip.py +++ b/src/geo_assistant/tools/naip.py @@ -1,19 +1,18 @@ # tools/naip_mpc_tools.py -from typing import Dict, Any, Optional, Annotated -from pathlib import Path - from concurrent.futures import ThreadPoolExecutor -import numpy as np -import xarray as xr -import matplotlib.pyplot as plt -from langchain_core.tools import tool -from pystac_client import Client -from odc.stac import stac_load -from langgraph.types import Command -from langchain_core.messages import ToolMessage -from langchain_core.tools.base import InjectedToolCallId +from pathlib import Path +from typing import Annotated, Any import dotenv +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langchain_core.tools.base import InjectedToolCallId +from langgraph.types import Command +from odc.stac import stac_load +from pystac_client import Client dotenv.load_dotenv() @@ -23,10 +22,10 @@ E84_STAC_URL = "https://earth-search.aws.element84.com/v1" @tool("fetch_naip_img") async def fetch_naip_img( - aoi_geojson: Dict[str, Any], + aoi_geojson: dict[str, Any], start_date: str, end_date: str, - tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None, + tool_call_id: Annotated[str | None, InjectedToolCallId] = None, ) -> Command: """ Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and @@ -56,10 +55,10 @@ async def fetch_naip_img( ToolMessage( content="No NAIP imagery found for the specified area and date range.", tool_call_id=tool_call_id, - ) + ), ], "naip_png_path": None, - } + }, ) # --- 2. Load as xarray cube with odc.stac --- @@ -81,14 +80,14 @@ async def fetch_naip_img( ToolMessage( content="Unable to load NAIP RGB image, dataset has no time dimension", tool_call_id=tool_call_id, - ) + ), ], "naip_png_path": None, - } + }, ) # --- 3. Build an RGB composite from the cube --- - # For the PNG, we’ll just use the first time slice (you can swap in “latest” + # For the PNG, we'll just use the first time slice (you can swap in “latest” # or a temporal reduction if you prefer). red = ds["Red"].isel(time=0) green = ds["Green"].isel(time=0) @@ -121,8 +120,8 @@ async def fetch_naip_img( ToolMessage( content=f"NAIP RGB image saved to {out_path.as_posix()}", tool_call_id=tool_call_id, - ) + ), ], "naip_png_path": out_path.as_posix(), - } + }, ) diff --git a/src/geo_assistant/tools/overture.py b/src/geo_assistant/tools/overture.py index 8d2c1ef..8183c4d 100644 --- a/src/geo_assistant/tools/overture.py +++ b/src/geo_assistant/tools/overture.py @@ -4,11 +4,11 @@ from typing import Annotated import duckdb from dotenv import load_dotenv +from geojson_pydantic import Feature from langchain_core.messages import ToolMessage from langchain_core.tools import tool from langchain_core.tools.base import InjectedToolCallId from langgraph.types import Command -from geojson_pydantic import Feature # Load environment variables load_dotenv() @@ -34,7 +34,8 @@ def create_database_connection(): @tool async def get_place( - place_name: str, tool_call_id: Annotated[str, InjectedToolCallId] = "" + place_name: str, + tool_call_id: Annotated[str, InjectedToolCallId] = "", ) -> Command: """Get place location from Overture Maps based on user input place name.""" @@ -65,7 +66,7 @@ async def get_place( WHERE jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) > 0.5 ORDER BY similarity_score DESC LIMIT 1; - """ + """, ).fetchall() db_connection.close() @@ -89,7 +90,7 @@ async def get_place( ToolMessage( content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}", tool_call_id=tool_call_id, - ) + ), ], }, ) diff --git a/src/geo_assistant/tools/summarize.py b/src/geo_assistant/tools/summarize.py index e0b8fe7..70b38fe 100644 --- a/src/geo_assistant/tools/summarize.py +++ b/src/geo_assistant/tools/summarize.py @@ -1,14 +1,14 @@ """Tools for summarizing satellite images using LLM-based analysis.""" import os -from typing import Annotated, Optional -import dspy -from langchain_core.tools import tool -from langgraph.types import Command -from langchain_core.messages import ToolMessage -from langchain_core.tools.base import InjectedToolCallId +from typing import Annotated import dotenv +import dspy +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langchain_core.tools.base import InjectedToolCallId +from langgraph.types import Command dotenv.load_dotenv() @@ -68,7 +68,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent() @tool async def summarize_sat_img( img_url: str, - tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None, + tool_call_id: Annotated[str | None, InjectedToolCallId] = None, ) -> Command: """Summarize the contents of a satellite image using an LLM. @@ -96,7 +96,7 @@ async def summarize_sat_img( content=message_content, artifact=artifact, tool_call_id=tool_call_id, - ) - ] - } + ), + ], + }, ) diff --git a/tests/test_api.py b/tests/test_api.py index d4dc955..a459e14 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,10 +1,11 @@ -import pytest -import pytest_asyncio -from httpx import AsyncClient, ASGITransport from uuid import uuid4 -from geo_assistant.api.app import app +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient + from geo_assistant.agent.graph import create_graph +from geo_assistant.api.app import app @pytest_asyncio.fixture @@ -22,7 +23,8 @@ async def initialized_app(): async def test_hello_world(initialized_app): """Hello world test for the API""" async with AsyncClient( - transport=ASGITransport(app=initialized_app), base_url="http://test" + transport=ASGITransport(app=initialized_app), + base_url="http://test", ) as client: thread_id = uuid4() response = await client.post( @@ -33,7 +35,7 @@ async def test_hello_world(initialized_app): { "content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it", "type": "human", - } + }, ], "place": None, "search_area": None, diff --git a/tests/tools/test_buffer.py b/tests/tools/test_buffer.py index 5fa6910..03c7404 100644 --- a/tests/tools/test_buffer.py +++ b/tests/tools/test_buffer.py @@ -1,8 +1,9 @@ -from pytest import fixture -from geo_assistant.agent.state import GeoAssistantState -from geo_assistant.tools.buffer import get_search_area from geojson_pydantic import Feature, Point from langchain_core.tools.base import ToolCall +from pytest import fixture + +from geo_assistant.agent.state import GeoAssistantState +from geo_assistant.tools.buffer import get_search_area @fixture diff --git a/tests/tools/test_naip.py b/tests/tools/test_naip.py index e9e03fb..e594f4f 100644 --- a/tests/tools/test_naip.py +++ b/tests/tools/test_naip.py @@ -1,7 +1,8 @@ -import pytest from pathlib import Path -from shapely.geometry import box, mapping + +import pytest from langchain_core.tools.base import ToolCall +from shapely.geometry import box, mapping from geo_assistant.tools.naip import fetch_naip_img @@ -19,7 +20,7 @@ async def test_fetch_naip(tmp_path): """ # Union Market coordinates from GeoNames: 38.90789, -76.99831 - # N 38°54′28″ W 76°59′54″ + # N 38°54'28" W 76°59'54" # We'll use a small neighborhood AOI around that point. # :contentReference[oaicite:0]{index=0} lat = 38.90789 @@ -44,7 +45,7 @@ async def test_fetch_naip(tmp_path): id="test_tool_call_id", ) - # Call the actual tool – no STAC / odc-stac mocking + # Call the actual tool - no STAC / odc-stac mocking result = await fetch_naip_img.ainvoke(tool_call["args"]) # Basic sanity checks on result diff --git a/tests/tools/test_overture.py b/tests/tools/test_overture.py index 3196644..264e5a8 100644 --- a/tests/tools/test_overture.py +++ b/tests/tools/test_overture.py @@ -1,4 +1,5 @@ import os + import pytest from langchain_core.tools.base import ToolCall @@ -24,6 +25,6 @@ async def test_get_place(): type="tool_call", id="test_id", args={"place_name": "Neighbourhood Cafe Lisbon"}, - ) + ), ) assert "place" in command.update diff --git a/tests/tools/test_summarize.py b/tests/tools/test_summarize.py index 8390f07..9d629c9 100644 --- a/tests/tools/test_summarize.py +++ b/tests/tools/test_summarize.py @@ -1,9 +1,10 @@ """Tests for the satellite image summarization tool.""" -import pytest import uuid +import pytest from langchain_core.tools.base import ToolCall + from geo_assistant.tools.summarize import summarize_sat_img # Sample test data @@ -24,7 +25,7 @@ def test_summarize_sat_img(img_url, summary): type="tool_call", args={"img_url": img_url}, id=str(uuid.uuid4()), - ) + ), ) print(command.update.get("messages"))