mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-15 15:31:02 +02:00
Enable pydocstyle (D) ruff rule and add more docs (#21)
* Use pydocstyle (D) rule with google convention Add a ruff rule to catch missing documentation. Using google convention so that undocumented-param (D417) rule is enabled to catch missing params, xref https://docs.astral.sh/ruff/rules/undocumented-param. Extended to include D213 (instead of D212) and D410 rules too. * Fix D100 Missing docstring in public module * Fix D101 Missing docstring in public class * Fix D103 Missing docstring in public function Also ignore rule D205 to allow first sentence of docstring to wrap to multiple lines. * Fix D417 Missing argument description in the docstring * Update indent in pyproject.toml file
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
"""Create agent graph that calls tools."""
|
||||
|
||||
import datetime
|
||||
|
||||
from langchain.agents import create_agent
|
||||
@@ -31,6 +33,7 @@ The current date and time is {now}.
|
||||
|
||||
|
||||
async def create_graph():
|
||||
"""Create langchain agent graph with a list of tools."""
|
||||
checkpointer = InMemorySaver()
|
||||
graph = create_agent(
|
||||
model=llm,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Ollama chat model."""
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""State schema for the geo-assistant agent."""
|
||||
|
||||
from typing import NotRequired
|
||||
|
||||
from geojson_pydantic import Feature, FeatureCollection
|
||||
@@ -6,6 +8,8 @@ from pydantic import Field
|
||||
|
||||
|
||||
class GeoAssistantState(AgentState):
|
||||
"""Schema for the geo-assistant agent's state."""
|
||||
|
||||
place: NotRequired[Feature | None] = None
|
||||
search_area: NotRequired[Feature | None] = None
|
||||
places_within_buffer: NotRequired[FeatureCollection | None] = None
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Chat app API endpoint."""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import aclosing, asynccontextmanager
|
||||
@@ -21,12 +23,12 @@ UI_SET_FIELDS_WHITELIST = ["point", "messages"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def _lifespan(app: FastAPI):
|
||||
app.state.chatbot = await create_graph()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Geo Assistant", lifespan=lifespan)
|
||||
app = FastAPI(title="Geo Assistant", lifespan=_lifespan)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
@@ -44,6 +46,7 @@ async def stream_chat(
|
||||
chatbot: Any,
|
||||
request: Request,
|
||||
) -> AsyncGenerator[bytes]:
|
||||
"""Agent chat stream."""
|
||||
config: dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": str(thread_id),
|
||||
@@ -101,6 +104,7 @@ async def stream_chat(
|
||||
|
||||
@app.post("/chat")
|
||||
async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse:
|
||||
"""HTTP POST endpoint at /chat."""
|
||||
generator = stream_chat(
|
||||
ui_state_update=request.agent_state_input,
|
||||
thread_id=request.thread_id,
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
"""Chat API schemas."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
|
||||
"""Schema for the request to the Chat API."""
|
||||
|
||||
thread_id: str
|
||||
agent_state_input: GeoAssistantState
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Schema for the response from the Chat API."""
|
||||
|
||||
thread_id: str
|
||||
state: GeoAssistantState
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Chat app frontend."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""List of tools available to the agent."""
|
||||
|
||||
from geo_assistant.tools.buffer import get_search_area
|
||||
from geo_assistant.tools.naip import fetch_naip_img
|
||||
from geo_assistant.tools.overture import get_place, get_places_within_buffer
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Tool to create a buffer polygon around a geometry feature."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
import geopandas as gpd
|
||||
@@ -17,8 +19,14 @@ async def get_search_area(
|
||||
state: Annotated[GeoAssistantState, InjectedState],
|
||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||
) -> Command:
|
||||
"""Get a search area buffer in km around the place defined in the agent state."""
|
||||
"""
|
||||
Get a search area buffer in km around the place defined in the agent state.
|
||||
|
||||
Args:
|
||||
buffer_size_km: Radius of the buffer in kilometres.
|
||||
state: Pass in 'place' as state into this agent.
|
||||
tool_call_id: Optional ID for tracking the tool call.
|
||||
"""
|
||||
place_feature = state.get("place")
|
||||
|
||||
if not place_feature:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# tools/naip_mpc_tools.py
|
||||
"""Tool to query Planetary Computer STAC API for NAIP imagery."""
|
||||
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
@@ -39,7 +40,8 @@ async def fetch_naip_img(
|
||||
Args:
|
||||
start_date: Start date (YYYY-MM-DD).
|
||||
end_date: End date (YYYY-MM-DD).
|
||||
|
||||
state: Pass in search_area as state into this agent.
|
||||
tool_call_id: Optional ID for tracking the tool call
|
||||
"""
|
||||
if not state["search_area"]:
|
||||
return Command(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Tool to find closest matching Overture place based on user input."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated
|
||||
@@ -21,7 +23,8 @@ load_dotenv()
|
||||
|
||||
|
||||
def create_database_connection():
|
||||
"""Create and configure a DuckDB connection with necessary extensions.
|
||||
"""
|
||||
Create and configure a DuckDB connection with necessary extensions.
|
||||
|
||||
Args:
|
||||
database_path: Path to the DuckDB database file
|
||||
@@ -43,8 +46,14 @@ async def get_place(
|
||||
place_name: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||
) -> Command:
|
||||
"""Get place location from Overture Maps based on user input place name."""
|
||||
"""
|
||||
Get place location from Overture Maps based on user input place name.
|
||||
|
||||
Args:
|
||||
place_name: An address or location given as a human-readable string.
|
||||
tool_call_id: Optional ID for tracking the tool call.
|
||||
|
||||
"""
|
||||
db_connection = create_database_connection()
|
||||
source = os.getenv("OVERTURE_SOURCE", "local")
|
||||
if source == "s3":
|
||||
@@ -162,10 +171,16 @@ async def get_places_within_buffer(
|
||||
state: Annotated[GeoAssistantState, InjectedState],
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> Command:
|
||||
"""Get places from Overture Maps within user specified area and user specified Overture place type.
|
||||
|
||||
Accepts: restaurant(s), cafe(s), coffee shop(s), bar(s), pub(s) - case insensitive."""
|
||||
"""
|
||||
Get places from Overture Maps within user specified area and user specified Overture
|
||||
place type.
|
||||
|
||||
Args:
|
||||
place: Overture place type. Accepts: restaurant(s), cafe(s), coffee shop(s),
|
||||
bar(s), pub(s) - case insensitive.
|
||||
state: Pass in 'search_area' as state into this agent.
|
||||
tool_call_id: Optional ID for tracking the tool call.
|
||||
"""
|
||||
# Normalize the place type
|
||||
place = normalize_place_type(place)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ dotenv.load_dotenv()
|
||||
|
||||
|
||||
class SatImgSummary(dspy.Signature):
|
||||
"Describe things you see in the satellite image."
|
||||
"""Describe things you see in the satellite image."""
|
||||
|
||||
img: dspy.Image = dspy.InputField(desc="A satellite image")
|
||||
answer: str = dspy.OutputField(desc="Description of the image")
|
||||
@@ -33,7 +33,8 @@ class SatImgSummaryAgent(dspy.Module):
|
||||
temperature: float = 0.5,
|
||||
max_tokens: int = 4_096,
|
||||
) -> None:
|
||||
"""Initialize the satellite image summary agent.
|
||||
"""
|
||||
Initialize the satellite image summary agent.
|
||||
|
||||
Args:
|
||||
model: The Ollama model to use for summarization
|
||||
@@ -53,7 +54,8 @@ class SatImgSummaryAgent(dspy.Module):
|
||||
self.summarizer = dspy.Predict(SatImgSummary)
|
||||
|
||||
def forward(self, img_url: str) -> dspy.Prediction:
|
||||
"""Generate a summary for the given image URL.
|
||||
"""
|
||||
Generate a summary for the given image URL.
|
||||
|
||||
Args:
|
||||
img_url: URL of the image to summarize
|
||||
@@ -73,7 +75,12 @@ async def summarize_sat_img(
|
||||
state: Annotated[GeoAssistantState, InjectedState],
|
||||
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||
) -> Command:
|
||||
"""Summarize the contents of a satellite image using an LLM.
|
||||
"""
|
||||
Summarize the contents of a satellite image using an LLM.
|
||||
|
||||
Args:
|
||||
state: Pass in 'naip_img_bytes' as state into this agent.
|
||||
tool_call_id: Optional ID for tracking the tool call.
|
||||
|
||||
Returns:
|
||||
Command containing the image summary and metadata
|
||||
|
||||
Reference in New Issue
Block a user