mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Enable pydocstyle (D) ruff rule and add more docs (#21)
* Use pydocstyle (D) rule with google convention Add a ruff rule to catch missing documentation. Using google convention so that undocumented-param (D417) rule is enabled to catch missing params, xref https://docs.astral.sh/ruff/rules/undocumented-param. Extended to include D213 (instead of D212) and D410 rules too. * Fix D100 Missing docstring in public module * Fix D101 Missing docstring in public class * Fix D103 Missing docstring in public function Also ignore rule D205 to allow first sentence of docstring to wrap to multiple lines. * Fix D417 Missing argument description in the docstring * Update indent in pyproject.toml file
This commit is contained in:
@@ -59,8 +59,22 @@ asyncio_mode = "auto"
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"COM", # flake8-commas
|
||||
"D", # pydocstyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"RUF", # ruff-specific
|
||||
"UP", # pyupgrade
|
||||
]
|
||||
extend-select = [
|
||||
"D213", # Summary lines should be positioned on the second physical line of the docstring.
|
||||
"D410", # A blank line after section headings.
|
||||
]
|
||||
ignore = [
|
||||
"D205", # 1 blank line required between summary line and description
|
||||
"D212", # Multi-line docstring summary should start at the first line
|
||||
]
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
# See https://docs.astral.sh/ruff/faq/#does-ruff-support-numpy-or-google-style-docstrings
|
||||
# for the enabled/disabled rules for the "google" convention.
|
||||
convention = "google"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Create agent graph that calls tools."""
|
||||
|
||||
import datetime
|
||||
|
||||
from langchain.agents import create_agent
|
||||
@@ -31,6 +33,7 @@ The current date and time is {now}.
|
||||
|
||||
|
||||
async def create_graph():
|
||||
"""Create langchain agent graph with a list of tools."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = create_agent(
|
||||
model=llm,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Ollama chat model."""
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""State schema for the geo-assistant agent."""
|
||||
|
||||
from typing import NotRequired
|
||||
|
||||
from geojson_pydantic import Feature, FeatureCollection
|
||||
@@ -6,6 +8,8 @@ from pydantic import Field
|
||||
|
||||
|
||||
class GeoAssistantState(AgentState):
|
||||
"""Schema for the geo-assistant agent's state."""
|
||||
|
||||
place: NotRequired[Feature | None] = None
|
||||
search_area: NotRequired[Feature | None] = None
|
||||
places_within_buffer: NotRequired[FeatureCollection | None] = None
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Chat app API endpoint."""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import aclosing, asynccontextmanager
|
||||
@@ -21,12 +23,12 @@ UI_SET_FIELDS_WHITELIST = ["point", "messages"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def _lifespan(app: FastAPI):
|
||||
app.state.chatbot = await create_graph()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Geo Assistant", lifespan=lifespan)
|
||||
app = FastAPI(title="Geo Assistant", lifespan=_lifespan)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
@@ -44,6 +46,7 @@ async def stream_chat(
|
||||
chatbot: Any,
|
||||
request: Request,
|
||||
) -> AsyncGenerator[bytes]:
|
||||
"""Agent chat stream."""
|
||||
config: dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": str(thread_id),
|
||||
@@ -101,6 +104,7 @@ async def stream_chat(
|
||||
|
||||
@app.post("/chat")
|
||||
async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse:
|
||||
"""HTTP POST endpoint at /chat."""
|
||||
generator = stream_chat(
|
||||
ui_state_update=request.agent_state_input,
|
||||
thread_id=request.thread_id,
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
"""Chat API schemas."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
|
||||
"""Schema for the request to the Chat API."""
|
||||
|
||||
thread_id: str
|
||||
agent_state_input: GeoAssistantState
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Schema for the response from the Chat API."""
|
||||
|
||||
thread_id: str
|
||||
state: GeoAssistantState
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Chat app frontend."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""List of tools available to the agent."""
|
||||
|
||||
from geo_assistant.tools.buffer import get_search_area
|
||||
from geo_assistant.tools.naip import fetch_naip_img
|
||||
from geo_assistant.tools.overture import get_place, get_places_within_buffer
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Tool to create a buffer polygon around a geometry feature."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
import geopandas as gpd
|
||||
@@ -17,8 +19,14 @@ async def get_search_area(
|
||||
state: Annotated[GeoAssistantState, InjectedState],
|
||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||
) -> Command:
|
||||
"""Get a search area buffer in km around the place defined in the agent state."""
|
||||
"""
|
||||
Get a search area buffer in km around the place defined in the agent state.
|
||||
|
||||
Args:
|
||||
buffer_size_km: Radius of the buffer in kilometres.
|
||||
state: Pass in 'place' as state into this agent.
|
||||
tool_call_id: Optional ID for tracking the tool call.
|
||||
"""
|
||||
place_feature = state.get("place")
|
||||
|
||||
if not place_feature:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# tools/naip_mpc_tools.py
|
||||
"""Tool to query Planetary Computer STAC API for NAIP imagery."""
|
||||
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
@@ -39,7 +40,8 @@ async def fetch_naip_img(
|
||||
Args:
|
||||
start_date: Start date (YYYY-MM-DD).
|
||||
end_date: End date (YYYY-MM-DD).
|
||||
|
||||
state: Pass in search_area as state into this agent.
|
||||
tool_call_id: Optional ID for tracking the tool call
|
||||
"""
|
||||
if not state["search_area"]:
|
||||
return Command(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Tool to find closest matching Overture place based on user input."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated
|
||||
@@ -21,7 +23,8 @@ load_dotenv()
|
||||
|
||||
|
||||
def create_database_connection():
|
||||
"""Create and configure a DuckDB connection with necessary extensions.
|
||||
"""
|
||||
Create and configure a DuckDB connection with necessary extensions.
|
||||
|
||||
Args:
|
||||
database_path: Path to the DuckDB database file
|
||||
@@ -43,8 +46,14 @@ async def get_place(
|
||||
place_name: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||
) -> Command:
|
||||
"""Get place location from Overture Maps based on user input place name."""
|
||||
"""
|
||||
Get place location from Overture Maps based on user input place name.
|
||||
|
||||
Args:
|
||||
place_name: An address or location given as a human-readable string.
|
||||
tool_call_id: Optional ID for tracking the tool call.
|
||||
|
||||
"""
|
||||
db_connection = create_database_connection()
|
||||
source = os.getenv("OVERTURE_SOURCE", "local")
|
||||
if source == "s3":
|
||||
@@ -162,10 +171,16 @@ async def get_places_within_buffer(
|
||||
state: Annotated[GeoAssistantState, InjectedState],
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> Command:
|
||||
"""Get places from Overture Maps within user specified area and user specified Overture place type.
|
||||
|
||||
Accepts: restaurant(s), cafe(s), coffee shop(s), bar(s), pub(s) - case insensitive."""
|
||||
"""
|
||||
Get places from Overture Maps within user specified area and user specified Overture
|
||||
place type.
|
||||
|
||||
Args:
|
||||
place: Overture place type. Accepts: restaurant(s), cafe(s), coffee shop(s),
|
||||
bar(s), pub(s) - case insensitive.
|
||||
state: Pass in 'search_area' as state into this agent.
|
||||
tool_call_id: Optional ID for tracking the tool call.
|
||||
"""
|
||||
# Normalize the place type
|
||||
place = normalize_place_type(place)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ dotenv.load_dotenv()
|
||||
|
||||
|
||||
class SatImgSummary(dspy.Signature):
|
||||
"Describe things you see in the satellite image."
|
||||
"""Describe things you see in the satellite image."""
|
||||
|
||||
img: dspy.Image = dspy.InputField(desc="A satellite image")
|
||||
answer: str = dspy.OutputField(desc="Description of the image")
|
||||
@@ -33,7 +33,8 @@ class SatImgSummaryAgent(dspy.Module):
|
||||
temperature: float = 0.5,
|
||||
max_tokens: int = 4_096,
|
||||
) -> None:
|
||||
"""Initialize the satellite image summary agent.
|
||||
"""
|
||||
Initialize the satellite image summary agent.
|
||||
|
||||
Args:
|
||||
model: The Ollama model to use for summarization
|
||||
@@ -53,7 +54,8 @@ class SatImgSummaryAgent(dspy.Module):
|
||||
self.summarizer = dspy.Predict(SatImgSummary)
|
||||
|
||||
def forward(self, img_url: str) -> dspy.Prediction:
|
||||
"""Generate a summary for the given image URL.
|
||||
"""
|
||||
Generate a summary for the given image URL.
|
||||
|
||||
Args:
|
||||
img_url: URL of the image to summarize
|
||||
@@ -73,7 +75,12 @@ async def summarize_sat_img(
|
||||
state: Annotated[GeoAssistantState, InjectedState],
|
||||
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||
) -> Command:
|
||||
"""Summarize the contents of a satellite image using an LLM.
|
||||
"""
|
||||
Summarize the contents of a satellite image using an LLM.
|
||||
|
||||
Args:
|
||||
state: Pass in 'naip_img_bytes' as state into this agent.
|
||||
tool_call_id: Optional ID for tracking the tool call.
|
||||
|
||||
Returns:
|
||||
Command containing the image summary and metadata
|
||||
|
||||
+4
-1
@@ -1,3 +1,5 @@
|
||||
"""Tests for chat API endpoint."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@@ -10,7 +12,7 @@ from geo_assistant.api.app import app
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def initialized_app():
|
||||
"""Initialize the app's chatbot before testing"""
|
||||
"""Initialize the app's chatbot before testing."""
|
||||
# Manually initialize the chatbot as the lifespan would
|
||||
app.state.chatbot = await create_graph()
|
||||
yield app
|
||||
@@ -21,6 +23,7 @@ async def initialized_app():
|
||||
|
||||
@pytest.mark.xfail
|
||||
async def test_call_api(initialized_app):
|
||||
"""Test calling the API at the /chat HTTP POST endpoint."""
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=initialized_app),
|
||||
base_url="http://test",
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Tests for buffer tool."""
|
||||
|
||||
from geojson_pydantic import Feature, Point
|
||||
from langchain_core.tools.base import ToolCall
|
||||
from pytest import fixture
|
||||
@@ -8,6 +10,7 @@ from geo_assistant.tools.buffer import get_search_area
|
||||
|
||||
@fixture
|
||||
def geo_assistant_fixture():
|
||||
"""Fixture with a GeoJSON point feature in a GeoAssistantState."""
|
||||
place_geojson = Feature(
|
||||
type="Feature",
|
||||
geometry=Point(type="Point", coordinates=[-9.1393, 38.7223]),
|
||||
@@ -22,6 +25,7 @@ def geo_assistant_fixture():
|
||||
|
||||
|
||||
async def test_get_search_area(geo_assistant_fixture):
|
||||
"""Ensure that `get_search_area` tool returns a buffer Polygon."""
|
||||
# Call the underlying function directly to test the logic
|
||||
# This bypasses the injection framework which is better suited for integration tests
|
||||
command = await get_search_area.ainvoke(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Tests for NAIP tool."""
|
||||
|
||||
from types import NoneType
|
||||
|
||||
import pytest
|
||||
@@ -19,7 +21,6 @@ async def test_fetch_naip():
|
||||
- Internet access (to reach Planetary Computer STAC + blobs)
|
||||
- Planetary Computer / NAIP service to be up
|
||||
"""
|
||||
|
||||
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
||||
# N 38°54'28" W 76°59'54"
|
||||
# We'll use a small neighborhood AOI around that point.
|
||||
@@ -60,7 +61,6 @@ async def test_fetch_naip_too_large():
|
||||
- Internet access (to reach Planetary Computer STAC + blobs)
|
||||
- Planetary Computer / NAIP service to be up
|
||||
"""
|
||||
|
||||
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
||||
# N 38°54'28" W 76°59'54"
|
||||
# We'll use a small neighborhood AOI around that point.
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Tests for Overture tool."""
|
||||
|
||||
import os
|
||||
|
||||
import geopandas as gpd
|
||||
@@ -57,6 +59,7 @@ def geo_assistant_with_buffer_fixture():
|
||||
|
||||
|
||||
async def test_get_place():
|
||||
"""Ensure that `get_place` tool returns an Overture place given a place_name."""
|
||||
command = await get_place.ainvoke(
|
||||
ToolCall(
|
||||
name="get_place",
|
||||
@@ -69,6 +72,10 @@ async def test_get_place():
|
||||
|
||||
|
||||
async def test_get_places_within_buffer(geo_assistant_with_buffer_fixture):
|
||||
"""
|
||||
Ensure that `get_places_within_buffer` tool returns multiple Overture places that
|
||||
fit match the category 'cafe' within a specific buffer area around a location.
|
||||
"""
|
||||
command = await get_places_within_buffer.ainvoke(
|
||||
ToolCall(
|
||||
name="get_places_within_buffer",
|
||||
|
||||
@@ -23,6 +23,10 @@ TEST_IMAGE_URL = "https://petapixel.com/assets/uploads/2022/08/French-Officials-
|
||||
],
|
||||
)
|
||||
async def test_summarize_sat_img(img_url, summary):
|
||||
"""
|
||||
Ensure that the `summarize_sat_img` tool can describe a satellite image in JPEG
|
||||
format.
|
||||
"""
|
||||
# Load the image from the supplied URL and encode it in base64
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
||||
|
||||
Reference in New Issue
Block a user