mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Use ruff rules COM, F, I, RUF, UP (#17)
* Use pyupgrade (UP) rule * Use pyflakes (F) rule * Use isort (I) rule * Use ruff-specific (RUF) rules * Use flake8-commas (COM) rule * Fix UP043 Unnecessary default type arguments
This commit is contained in:
@@ -47,3 +47,12 @@ packages = ["src/geo_assistant"]
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--color=yes"
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"COM", # flake8-commas
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"RUF", # ruff-specific
|
||||
"UP", # pyupgrade
|
||||
]
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import datetime
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langchain.agents import create_agent
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from geo_assistant.agent.llms import llm
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from geo_assistant.tools.buffer import get_search_area
|
||||
from geo_assistant.tools.naip import fetch_naip_img
|
||||
from geo_assistant.tools.overture import get_place
|
||||
from geo_assistant.tools.buffer import get_search_area
|
||||
from geo_assistant.tools.summarize import summarize_sat_img
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
@@ -36,7 +37,7 @@ async def create_graph():
|
||||
summarize_sat_img,
|
||||
],
|
||||
system_prompt=SYSTEM_PROMPT.format(
|
||||
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
),
|
||||
state_schema=GeoAssistantState,
|
||||
checkpointer=checkpointer,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
# Load environment variables from env file
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from langchain.agents import AgentState
|
||||
from typing import NotRequired
|
||||
|
||||
from geojson_pydantic import Feature
|
||||
from typing_extensions import NotRequired
|
||||
from langchain.agents import AgentState
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
@@ -8,5 +9,6 @@ class GeoAssistantState(AgentState):
|
||||
place: NotRequired[Feature | None] = None
|
||||
search_area: NotRequired[Feature | None] = None
|
||||
naip_png_path: NotRequired[str | None] = Field(
|
||||
default=None, description="Path to the saved NAIP RGB PNG image"
|
||||
default=None,
|
||||
description="Path to the saved NAIP RGB PNG image",
|
||||
)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import aclosing, asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
import logging
|
||||
from pydantic import UUID4
|
||||
|
||||
from geo_assistant.agent.graph import create_graph
|
||||
@@ -43,11 +44,11 @@ async def stream_chat(
|
||||
thread_id: UUID4,
|
||||
chatbot: Any,
|
||||
request: Request,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
config: Dict[str, Any] = {
|
||||
) -> AsyncGenerator[bytes]:
|
||||
config: dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": str(thread_id),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
state_updates = {}
|
||||
@@ -69,7 +70,7 @@ async def stream_chat(
|
||||
{
|
||||
"content": f"Manually selected data for field {key}: {description}",
|
||||
"type": "human",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Add UI messages to the existing messages if they exist
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@ import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import folium
|
||||
import httpx
|
||||
import streamlit as st
|
||||
import streamlit.components.v1 as components
|
||||
import folium
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
@@ -137,7 +137,10 @@ def stream_chat(user_message: str):
|
||||
# Fit map to bounds if we have coordinates
|
||||
if all_lons and all_lats:
|
||||
m.fit_bounds(
|
||||
[[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]]
|
||||
[
|
||||
[min(all_lats), min(all_lons)],
|
||||
[max(all_lats), max(all_lons)],
|
||||
],
|
||||
)
|
||||
|
||||
# Display the map
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from typing import Annotated
|
||||
|
||||
import geopandas as gpd
|
||||
from langgraph.types import Command
|
||||
from langgraph.prebuilt import InjectedState
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from geojson_pydantic import Feature
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.prebuilt import InjectedState
|
||||
from langgraph.types import Command
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from geojson_pydantic import Feature
|
||||
|
||||
|
||||
@tool
|
||||
@@ -26,9 +28,9 @@ async def get_search_area(
|
||||
ToolMessage(
|
||||
content="No place defined in the agent state to create a search area around.",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Convert GeoJSON feature to GeoDataFrame
|
||||
@@ -38,7 +40,7 @@ async def get_search_area(
|
||||
gdf_m = gdf.to_crs(epsg=3857) # latlon to Web Mercator for meter-based buffering
|
||||
|
||||
gdf_m["geometry"] = gdf_m["geometry"].buffer(
|
||||
buffer_size_km * 1000
|
||||
buffer_size_km * 1000,
|
||||
) # Buffer in meters
|
||||
gdf = gdf_m.to_crs(epsg=4326) # Back to WGS84
|
||||
|
||||
@@ -46,7 +48,7 @@ async def get_search_area(
|
||||
if len(gdf) != 1:
|
||||
raise ValueError(
|
||||
f"{len(gdf)} features found after buffer operation, should be just 1. "
|
||||
"Was a Multi-Point/LineString/Polygon geometry passed in?"
|
||||
"Was a Multi-Point/LineString/Polygon geometry passed in?",
|
||||
)
|
||||
buffer_feature = Feature(
|
||||
type="Feature",
|
||||
@@ -61,7 +63,7 @@ async def get_search_area(
|
||||
ToolMessage(
|
||||
content=f"Created search area geometry buffer of {buffer_size_km} km around the place.",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
# tools/naip_mpc_tools.py
|
||||
from typing import Dict, Any, Optional, Annotated
|
||||
from pathlib import Path
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
import matplotlib.pyplot as plt
|
||||
from langchain_core.tools import tool
|
||||
from pystac_client import Client
|
||||
from odc.stac import stac_load
|
||||
from langgraph.types import Command
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
import dotenv
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.types import Command
|
||||
from odc.stac import stac_load
|
||||
from pystac_client import Client
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@@ -23,10 +22,10 @@ E84_STAC_URL = "https://earth-search.aws.element84.com/v1"
|
||||
|
||||
@tool("fetch_naip_img")
|
||||
async def fetch_naip_img(
|
||||
aoi_geojson: Dict[str, Any],
|
||||
aoi_geojson: dict[str, Any],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
|
||||
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||
) -> Command:
|
||||
"""
|
||||
Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and
|
||||
@@ -56,10 +55,10 @@ async def fetch_naip_img(
|
||||
ToolMessage(
|
||||
content="No NAIP imagery found for the specified area and date range.",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
"naip_png_path": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# --- 2. Load as xarray cube with odc.stac ---
|
||||
@@ -81,14 +80,14 @@ async def fetch_naip_img(
|
||||
ToolMessage(
|
||||
content="Unable to load NAIP RGB image, dataset has no time dimension",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
"naip_png_path": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# --- 3. Build an RGB composite from the cube ---
|
||||
# For the PNG, we’ll just use the first time slice (you can swap in “latest”
|
||||
# For the PNG, we'll just use the first time slice (you can swap in “latest”
|
||||
# or a temporal reduction if you prefer).
|
||||
red = ds["Red"].isel(time=0)
|
||||
green = ds["Green"].isel(time=0)
|
||||
@@ -121,8 +120,8 @@ async def fetch_naip_img(
|
||||
ToolMessage(
|
||||
content=f"NAIP RGB image saved to {out_path.as_posix()}",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
"naip_png_path": out_path.as_posix(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -4,11 +4,11 @@ from typing import Annotated
|
||||
|
||||
import duckdb
|
||||
from dotenv import load_dotenv
|
||||
from geojson_pydantic import Feature
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.types import Command
|
||||
from geojson_pydantic import Feature
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
@@ -34,7 +34,8 @@ def create_database_connection():
|
||||
|
||||
@tool
|
||||
async def get_place(
|
||||
place_name: str, tool_call_id: Annotated[str, InjectedToolCallId] = ""
|
||||
place_name: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||
) -> Command:
|
||||
"""Get place location from Overture Maps based on user input place name."""
|
||||
|
||||
@@ -65,7 +66,7 @@ async def get_place(
|
||||
WHERE jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) > 0.5
|
||||
ORDER BY similarity_score DESC
|
||||
LIMIT 1;
|
||||
"""
|
||||
""",
|
||||
).fetchall()
|
||||
|
||||
db_connection.close()
|
||||
@@ -89,7 +90,7 @@ async def get_place(
|
||||
ToolMessage(
|
||||
content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""Tools for summarizing satellite images using LLM-based analysis."""
|
||||
|
||||
import os
|
||||
from typing import Annotated, Optional
|
||||
import dspy
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from typing import Annotated
|
||||
|
||||
import dotenv
|
||||
import dspy
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.types import Command
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@@ -68,7 +68,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
|
||||
@tool
|
||||
async def summarize_sat_img(
|
||||
img_url: str,
|
||||
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
|
||||
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||
) -> Command:
|
||||
"""Summarize the contents of a satellite image using an LLM.
|
||||
|
||||
@@ -96,7 +96,7 @@ async def summarize_sat_img(
|
||||
content=message_content,
|
||||
artifact=artifact,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
]
|
||||
}
|
||||
),
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
+8
-6
@@ -1,10 +1,11 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from uuid import uuid4
|
||||
|
||||
from geo_assistant.api.app import app
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from geo_assistant.agent.graph import create_graph
|
||||
from geo_assistant.api.app import app
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -22,7 +23,8 @@ async def initialized_app():
|
||||
async def test_hello_world(initialized_app):
|
||||
"""Hello world test for the API"""
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=initialized_app), base_url="http://test"
|
||||
transport=ASGITransport(app=initialized_app),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
thread_id = uuid4()
|
||||
response = await client.post(
|
||||
@@ -33,7 +35,7 @@ async def test_hello_world(initialized_app):
|
||||
{
|
||||
"content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it",
|
||||
"type": "human",
|
||||
}
|
||||
},
|
||||
],
|
||||
"place": None,
|
||||
"search_area": None,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from pytest import fixture
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from geo_assistant.tools.buffer import get_search_area
|
||||
from geojson_pydantic import Feature, Point
|
||||
from langchain_core.tools.base import ToolCall
|
||||
from pytest import fixture
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from geo_assistant.tools.buffer import get_search_area
|
||||
|
||||
|
||||
@fixture
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from shapely.geometry import box, mapping
|
||||
|
||||
import pytest
|
||||
from langchain_core.tools.base import ToolCall
|
||||
from shapely.geometry import box, mapping
|
||||
|
||||
from geo_assistant.tools.naip import fetch_naip_img
|
||||
|
||||
@@ -19,7 +20,7 @@ async def test_fetch_naip(tmp_path):
|
||||
"""
|
||||
|
||||
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
||||
# N 38°54′28″ W 76°59′54″
|
||||
# N 38°54'28" W 76°59'54"
|
||||
# We'll use a small neighborhood AOI around that point.
|
||||
# :contentReference[oaicite:0]{index=0}
|
||||
lat = 38.90789
|
||||
@@ -44,7 +45,7 @@ async def test_fetch_naip(tmp_path):
|
||||
id="test_tool_call_id",
|
||||
)
|
||||
|
||||
# Call the actual tool – no STAC / odc-stac mocking
|
||||
# Call the actual tool - no STAC / odc-stac mocking
|
||||
result = await fetch_naip_img.ainvoke(tool_call["args"])
|
||||
|
||||
# Basic sanity checks on result
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain_core.tools.base import ToolCall
|
||||
|
||||
@@ -24,6 +25,6 @@ async def test_get_place():
|
||||
type="tool_call",
|
||||
id="test_id",
|
||||
args={"place_name": "Neighbourhood Cafe Lisbon"},
|
||||
)
|
||||
),
|
||||
)
|
||||
assert "place" in command.update
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Tests for the satellite image summarization tool."""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from langchain_core.tools.base import ToolCall
|
||||
|
||||
from geo_assistant.tools.summarize import summarize_sat_img
|
||||
|
||||
# Sample test data
|
||||
@@ -24,7 +25,7 @@ def test_summarize_sat_img(img_url, summary):
|
||||
type="tool_call",
|
||||
args={"img_url": img_url},
|
||||
id=str(uuid.uuid4()),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
print(command.update.get("messages"))
|
||||
|
||||
Reference in New Issue
Block a user