mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Enable pydocstyle (D) ruff rule and add more docs (#21)
* Use pydocstyle (D) rule with google convention

  Add a ruff rule to catch missing documentation. The google convention is used so that the undocumented-param (D417) rule is enabled to catch missing params; see https://docs.astral.sh/ruff/rules/undocumented-param. Extended to include the D213 (instead of D212) and D410 rules too.

* Fix D100 Missing docstring in public module
* Fix D101 Missing docstring in public class
* Fix D103 Missing docstring in public function

  Also ignore rule D205 to allow the first sentence of a docstring to wrap to multiple lines.

* Fix D417 Missing argument description in the docstring
* Update indent in pyproject.toml file
This commit is contained in:
@@ -59,8 +59,22 @@ asyncio_mode = "auto"
|
|||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
|
"D", # pydocstyle
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
"I", # isort
|
"I", # isort
|
||||||
"RUF", # ruff-specific
|
"RUF", # ruff-specific
|
||||||
"UP", # pyupgrade
|
"UP", # pyupgrade
|
||||||
]
|
]
|
||||||
|
extend-select = [
|
||||||
|
"D213", # Summary lines should be positioned on the second physical line of the docstring.
|
||||||
|
"D410", # A blank line after section headings.
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"D205", # 1 blank line required between summary line and description
|
||||||
|
"D212", # Multi-line docstring summary should start at the first line
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.pydocstyle]
|
||||||
|
# See https://docs.astral.sh/ruff/faq/#does-ruff-support-numpy-or-google-style-docstrings
|
||||||
|
# for the enabled/disabled rules for the "google" convention.
|
||||||
|
convention = "google"
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Create agent graph that calls tools."""
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
@@ -31,6 +33,7 @@ The current date and time is {now}.
|
|||||||
|
|
||||||
|
|
||||||
async def create_graph():
|
async def create_graph():
|
||||||
|
"""Create langchain agent graph with a list of tools."""
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
graph = create_agent(
|
graph = create_agent(
|
||||||
model=llm,
|
model=llm,
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Ollama chat model."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""State schema for the geo-assistant agent."""
|
||||||
|
|
||||||
from typing import NotRequired
|
from typing import NotRequired
|
||||||
|
|
||||||
from geojson_pydantic import Feature, FeatureCollection
|
from geojson_pydantic import Feature, FeatureCollection
|
||||||
@@ -6,6 +8,8 @@ from pydantic import Field
|
|||||||
|
|
||||||
|
|
||||||
class GeoAssistantState(AgentState):
|
class GeoAssistantState(AgentState):
|
||||||
|
"""Schema for the geo-assistant agent's state."""
|
||||||
|
|
||||||
place: NotRequired[Feature | None] = None
|
place: NotRequired[Feature | None] = None
|
||||||
search_area: NotRequired[Feature | None] = None
|
search_area: NotRequired[Feature | None] = None
|
||||||
places_within_buffer: NotRequired[FeatureCollection | None] = None
|
places_within_buffer: NotRequired[FeatureCollection | None] = None
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Chat app API endpoint."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import aclosing, asynccontextmanager
|
from contextlib import aclosing, asynccontextmanager
|
||||||
@@ -21,12 +23,12 @@ UI_SET_FIELDS_WHITELIST = ["point", "messages"]
|
|||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def _lifespan(app: FastAPI):
|
||||||
app.state.chatbot = await create_graph()
|
app.state.chatbot = await create_graph()
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Geo Assistant", lifespan=lifespan)
|
app = FastAPI(title="Geo Assistant", lifespan=_lifespan)
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
@@ -44,6 +46,7 @@ async def stream_chat(
|
|||||||
chatbot: Any,
|
chatbot: Any,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> AsyncGenerator[bytes]:
|
) -> AsyncGenerator[bytes]:
|
||||||
|
"""Agent chat stream."""
|
||||||
config: dict[str, Any] = {
|
config: dict[str, Any] = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"thread_id": str(thread_id),
|
"thread_id": str(thread_id),
|
||||||
@@ -101,6 +104,7 @@ async def stream_chat(
|
|||||||
|
|
||||||
@app.post("/chat")
|
@app.post("/chat")
|
||||||
async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse:
|
async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse:
|
||||||
|
"""HTTP POST endpoint at /chat."""
|
||||||
generator = stream_chat(
|
generator = stream_chat(
|
||||||
ui_state_update=request.agent_state_input,
|
ui_state_update=request.agent_state_input,
|
||||||
thread_id=request.thread_id,
|
thread_id=request.thread_id,
|
||||||
|
|||||||
@@ -1,13 +1,19 @@
|
|||||||
|
"""Chat API schemas."""
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from geo_assistant.agent.state import GeoAssistantState
|
from geo_assistant.agent.state import GeoAssistantState
|
||||||
|
|
||||||
|
|
||||||
class ChatRequestBody(BaseModel):
|
class ChatRequestBody(BaseModel):
|
||||||
|
"""Schema for the request to the Chat API."""
|
||||||
|
|
||||||
thread_id: str
|
thread_id: str
|
||||||
agent_state_input: GeoAssistantState
|
agent_state_input: GeoAssistantState
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
|
"""Schema for the response from the Chat API."""
|
||||||
|
|
||||||
thread_id: str
|
thread_id: str
|
||||||
state: GeoAssistantState
|
state: GeoAssistantState
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Chat app frontend."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""List of tools available to the agent."""
|
||||||
|
|
||||||
from geo_assistant.tools.buffer import get_search_area
|
from geo_assistant.tools.buffer import get_search_area
|
||||||
from geo_assistant.tools.naip import fetch_naip_img
|
from geo_assistant.tools.naip import fetch_naip_img
|
||||||
from geo_assistant.tools.overture import get_place, get_places_within_buffer
|
from geo_assistant.tools.overture import get_place, get_places_within_buffer
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Tool to create a buffer polygon around a geometry feature."""
|
||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
@@ -17,8 +19,14 @@ async def get_search_area(
|
|||||||
state: Annotated[GeoAssistantState, InjectedState],
|
state: Annotated[GeoAssistantState, InjectedState],
|
||||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||||
) -> Command:
|
) -> Command:
|
||||||
"""Get a search area buffer in km around the place defined in the agent state."""
|
"""
|
||||||
|
Get a search area buffer in km around the place defined in the agent state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
buffer_size_km: Radius of the buffer in kilometres.
|
||||||
|
state: Pass in 'place' as state into this agent.
|
||||||
|
tool_call_id: Optional ID for tracking the tool call.
|
||||||
|
"""
|
||||||
place_feature = state.get("place")
|
place_feature = state.get("place")
|
||||||
|
|
||||||
if not place_feature:
|
if not place_feature:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# tools/naip_mpc_tools.py
|
"""Tool to query Planetary Computer STAC API for NAIP imagery."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@@ -39,7 +40,8 @@ async def fetch_naip_img(
|
|||||||
Args:
|
Args:
|
||||||
start_date: Start date (YYYY-MM-DD).
|
start_date: Start date (YYYY-MM-DD).
|
||||||
end_date: End date (YYYY-MM-DD).
|
end_date: End date (YYYY-MM-DD).
|
||||||
|
state: Pass in search_area as state into this agent.
|
||||||
|
tool_call_id: Optional ID for tracking the tool call
|
||||||
"""
|
"""
|
||||||
if not state["search_area"]:
|
if not state["search_area"]:
|
||||||
return Command(
|
return Command(
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Tool to find closest matching Overture place based on user input."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
@@ -21,7 +23,8 @@ load_dotenv()
|
|||||||
|
|
||||||
|
|
||||||
def create_database_connection():
|
def create_database_connection():
|
||||||
"""Create and configure a DuckDB connection with necessary extensions.
|
"""
|
||||||
|
Create and configure a DuckDB connection with necessary extensions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
database_path: Path to the DuckDB database file
|
database_path: Path to the DuckDB database file
|
||||||
@@ -43,8 +46,14 @@ async def get_place(
|
|||||||
place_name: str,
|
place_name: str,
|
||||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||||
) -> Command:
|
) -> Command:
|
||||||
"""Get place location from Overture Maps based on user input place name."""
|
"""
|
||||||
|
Get place location from Overture Maps based on user input place name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
place_name: An address or location given as a human-readable string.
|
||||||
|
tool_call_id: Optional ID for tracking the tool call.
|
||||||
|
|
||||||
|
"""
|
||||||
db_connection = create_database_connection()
|
db_connection = create_database_connection()
|
||||||
source = os.getenv("OVERTURE_SOURCE", "local")
|
source = os.getenv("OVERTURE_SOURCE", "local")
|
||||||
if source == "s3":
|
if source == "s3":
|
||||||
@@ -162,10 +171,16 @@ async def get_places_within_buffer(
|
|||||||
state: Annotated[GeoAssistantState, InjectedState],
|
state: Annotated[GeoAssistantState, InjectedState],
|
||||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||||
) -> Command:
|
) -> Command:
|
||||||
"""Get places from Overture Maps within user specified area and user specified Overture place type.
|
"""
|
||||||
|
Get places from Overture Maps within user specified area and user specified Overture
|
||||||
Accepts: restaurant(s), cafe(s), coffee shop(s), bar(s), pub(s) - case insensitive."""
|
place type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
place: Overture place type. Accepts: restaurant(s), cafe(s), coffee shop(s),
|
||||||
|
bar(s), pub(s) - case insensitive.
|
||||||
|
state: Pass in 'search_area' as state into this agent.
|
||||||
|
tool_call_id: Optional ID for tracking the tool call.
|
||||||
|
"""
|
||||||
# Normalize the place type
|
# Normalize the place type
|
||||||
place = normalize_place_type(place)
|
place = normalize_place_type(place)
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ dotenv.load_dotenv()
|
|||||||
|
|
||||||
|
|
||||||
class SatImgSummary(dspy.Signature):
|
class SatImgSummary(dspy.Signature):
|
||||||
"Describe things you see in the satellite image."
|
"""Describe things you see in the satellite image."""
|
||||||
|
|
||||||
img: dspy.Image = dspy.InputField(desc="A satellite image")
|
img: dspy.Image = dspy.InputField(desc="A satellite image")
|
||||||
answer: str = dspy.OutputField(desc="Description of the image")
|
answer: str = dspy.OutputField(desc="Description of the image")
|
||||||
@@ -33,7 +33,8 @@ class SatImgSummaryAgent(dspy.Module):
|
|||||||
temperature: float = 0.5,
|
temperature: float = 0.5,
|
||||||
max_tokens: int = 4_096,
|
max_tokens: int = 4_096,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the satellite image summary agent.
|
"""
|
||||||
|
Initialize the satellite image summary agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The Ollama model to use for summarization
|
model: The Ollama model to use for summarization
|
||||||
@@ -53,7 +54,8 @@ class SatImgSummaryAgent(dspy.Module):
|
|||||||
self.summarizer = dspy.Predict(SatImgSummary)
|
self.summarizer = dspy.Predict(SatImgSummary)
|
||||||
|
|
||||||
def forward(self, img_url: str) -> dspy.Prediction:
|
def forward(self, img_url: str) -> dspy.Prediction:
|
||||||
"""Generate a summary for the given image URL.
|
"""
|
||||||
|
Generate a summary for the given image URL.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img_url: URL of the image to summarize
|
img_url: URL of the image to summarize
|
||||||
@@ -73,7 +75,12 @@ async def summarize_sat_img(
|
|||||||
state: Annotated[GeoAssistantState, InjectedState],
|
state: Annotated[GeoAssistantState, InjectedState],
|
||||||
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||||
) -> Command:
|
) -> Command:
|
||||||
"""Summarize the contents of a satellite image using an LLM.
|
"""
|
||||||
|
Summarize the contents of a satellite image using an LLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Pass in 'naip_img_bytes' as state into this agent.
|
||||||
|
tool_call_id: Optional ID for tracking the tool call.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Command containing the image summary and metadata
|
Command containing the image summary and metadata
|
||||||
|
|||||||
+4
-1
@@ -1,3 +1,5 @@
|
|||||||
|
"""Tests for chat API endpoint."""
|
||||||
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -10,7 +12,7 @@ from geo_assistant.api.app import app
|
|||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
async def initialized_app():
|
async def initialized_app():
|
||||||
"""Initialize the app's chatbot before testing"""
|
"""Initialize the app's chatbot before testing."""
|
||||||
# Manually initialize the chatbot as the lifespan would
|
# Manually initialize the chatbot as the lifespan would
|
||||||
app.state.chatbot = await create_graph()
|
app.state.chatbot = await create_graph()
|
||||||
yield app
|
yield app
|
||||||
@@ -21,6 +23,7 @@ async def initialized_app():
|
|||||||
|
|
||||||
@pytest.mark.xfail
|
@pytest.mark.xfail
|
||||||
async def test_call_api(initialized_app):
|
async def test_call_api(initialized_app):
|
||||||
|
"""Test calling the API at the /chat HTTP POST endpoint."""
|
||||||
async with AsyncClient(
|
async with AsyncClient(
|
||||||
transport=ASGITransport(app=initialized_app),
|
transport=ASGITransport(app=initialized_app),
|
||||||
base_url="http://test",
|
base_url="http://test",
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Tests for buffer tool."""
|
||||||
|
|
||||||
from geojson_pydantic import Feature, Point
|
from geojson_pydantic import Feature, Point
|
||||||
from langchain_core.tools.base import ToolCall
|
from langchain_core.tools.base import ToolCall
|
||||||
from pytest import fixture
|
from pytest import fixture
|
||||||
@@ -8,6 +10,7 @@ from geo_assistant.tools.buffer import get_search_area
|
|||||||
|
|
||||||
@fixture
|
@fixture
|
||||||
def geo_assistant_fixture():
|
def geo_assistant_fixture():
|
||||||
|
"""Fixture with a GeoJSON point feature in a GeoAssistantState."""
|
||||||
place_geojson = Feature(
|
place_geojson = Feature(
|
||||||
type="Feature",
|
type="Feature",
|
||||||
geometry=Point(type="Point", coordinates=[-9.1393, 38.7223]),
|
geometry=Point(type="Point", coordinates=[-9.1393, 38.7223]),
|
||||||
@@ -22,6 +25,7 @@ def geo_assistant_fixture():
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_search_area(geo_assistant_fixture):
|
async def test_get_search_area(geo_assistant_fixture):
|
||||||
|
"""Ensure that `get_search_area` tool returns a buffer Polygon."""
|
||||||
# Call the underlying function directly to test the logic
|
# Call the underlying function directly to test the logic
|
||||||
# This bypasses the injection framework which is better suited for integration tests
|
# This bypasses the injection framework which is better suited for integration tests
|
||||||
command = await get_search_area.ainvoke(
|
command = await get_search_area.ainvoke(
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Tests for NAIP tool."""
|
||||||
|
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -19,7 +21,6 @@ async def test_fetch_naip():
|
|||||||
- Internet access (to reach Planetary Computer STAC + blobs)
|
- Internet access (to reach Planetary Computer STAC + blobs)
|
||||||
- Planetary Computer / NAIP service to be up
|
- Planetary Computer / NAIP service to be up
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
||||||
# N 38°54'28" W 76°59'54"
|
# N 38°54'28" W 76°59'54"
|
||||||
# We'll use a small neighborhood AOI around that point.
|
# We'll use a small neighborhood AOI around that point.
|
||||||
@@ -60,7 +61,6 @@ async def test_fetch_naip_too_large():
|
|||||||
- Internet access (to reach Planetary Computer STAC + blobs)
|
- Internet access (to reach Planetary Computer STAC + blobs)
|
||||||
- Planetary Computer / NAIP service to be up
|
- Planetary Computer / NAIP service to be up
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
||||||
# N 38°54'28" W 76°59'54"
|
# N 38°54'28" W 76°59'54"
|
||||||
# We'll use a small neighborhood AOI around that point.
|
# We'll use a small neighborhood AOI around that point.
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Tests for Overture tool."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
@@ -57,6 +59,7 @@ def geo_assistant_with_buffer_fixture():
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_place():
|
async def test_get_place():
|
||||||
|
"""Ensure that `get_place` tool returns an Overture place given a place_name."""
|
||||||
command = await get_place.ainvoke(
|
command = await get_place.ainvoke(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
name="get_place",
|
name="get_place",
|
||||||
@@ -69,6 +72,10 @@ async def test_get_place():
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_places_within_buffer(geo_assistant_with_buffer_fixture):
|
async def test_get_places_within_buffer(geo_assistant_with_buffer_fixture):
|
||||||
|
"""
|
||||||
|
Ensure that `get_places_within_buffer` tool returns multiple Overture places that
|
||||||
|
match the category 'cafe' within a specific buffer area around a location.
|
||||||
|
"""
|
||||||
command = await get_places_within_buffer.ainvoke(
|
command = await get_places_within_buffer.ainvoke(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
name="get_places_within_buffer",
|
name="get_places_within_buffer",
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ TEST_IMAGE_URL = "https://petapixel.com/assets/uploads/2022/08/French-Officials-
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_summarize_sat_img(img_url, summary):
|
async def test_summarize_sat_img(img_url, summary):
|
||||||
|
"""
|
||||||
|
Ensure that the `summarize_sat_img` tool can describe a satellite image in JPEG
|
||||||
|
format.
|
||||||
|
"""
|
||||||
# Load the image from the supplied URL and encode it in base64
|
# Load the image from the supplied URL and encode it in base64
|
||||||
headers = {
|
headers = {
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
||||||
|
|||||||
Reference in New Issue
Block a user