Enable pydocstyle (D) ruff rule and add more docs (#21)

* Use pydocstyle (D) rule with google convention

Add a ruff rule to catch missing documentation. The google convention is used so that the undocumented-param (D417) rule is enabled to catch missing parameter descriptions; see https://docs.astral.sh/ruff/rules/undocumented-param. The selection is also extended to include the D213 (instead of D212) and D410 rules.

* Fix D100 Missing docstring in public module

* Fix D101 Missing docstring in public class

* Fix D103 Missing docstring in public function

Also ignore rule D205 to allow first sentence of docstring to wrap to multiple lines.

* Fix D417 Missing argument description in the docstring

* Update indent in pyproject.toml file
This commit is contained in:
Wei Ji
2025-12-09 21:22:01 +13:00
committed by GitHub
parent 5782d890d6
commit e3373026d6
17 changed files with 104 additions and 17 deletions
+14
View File
@@ -59,8 +59,22 @@ asyncio_mode = "auto"
[tool.ruff.lint] [tool.ruff.lint]
select = [ select = [
"COM", # flake8-commas "COM", # flake8-commas
"D", # pydocstyle
"F", # pyflakes "F", # pyflakes
"I", # isort "I", # isort
"RUF", # ruff-specific "RUF", # ruff-specific
"UP", # pyupgrade "UP", # pyupgrade
] ]
extend-select = [
"D213", # Summary lines should be positioned on the second physical line of the docstring.
"D410", # A blank line after section headings.
]
ignore = [
"D205", # 1 blank line required between summary line and description
"D212", # Multi-line docstring summary should start at the first line
]
[tool.ruff.lint.pydocstyle]
# See https://docs.astral.sh/ruff/faq/#does-ruff-support-numpy-or-google-style-docstrings
# for the enabled/disabled rules for the "google" convention.
convention = "google"
+3
View File
@@ -1,3 +1,5 @@
"""Create agent graph that calls tools."""
import datetime import datetime
from langchain.agents import create_agent from langchain.agents import create_agent
@@ -31,6 +33,7 @@ The current date and time is {now}.
async def create_graph(): async def create_graph():
"""Create langchain agent graph with a list of tools."""
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
graph = create_agent( graph = create_agent(
model=llm, model=llm,
+2
View File
@@ -1,3 +1,5 @@
"""Ollama chat model."""
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
+4
View File
@@ -1,3 +1,5 @@
"""State schema for the geo-assistant agent."""
from typing import NotRequired from typing import NotRequired
from geojson_pydantic import Feature, FeatureCollection from geojson_pydantic import Feature, FeatureCollection
@@ -6,6 +8,8 @@ from pydantic import Field
class GeoAssistantState(AgentState): class GeoAssistantState(AgentState):
"""Schema for the geo-assistant agent's state."""
place: NotRequired[Feature | None] = None place: NotRequired[Feature | None] = None
search_area: NotRequired[Feature | None] = None search_area: NotRequired[Feature | None] = None
places_within_buffer: NotRequired[FeatureCollection | None] = None places_within_buffer: NotRequired[FeatureCollection | None] = None
+6 -2
View File
@@ -1,3 +1,5 @@
"""Chat app API endpoint."""
import logging import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import aclosing, asynccontextmanager from contextlib import aclosing, asynccontextmanager
@@ -21,12 +23,12 @@ UI_SET_FIELDS_WHITELIST = ["point", "messages"]
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def _lifespan(app: FastAPI):
app.state.chatbot = await create_graph() app.state.chatbot = await create_graph()
yield yield
app = FastAPI(title="Geo Assistant", lifespan=lifespan) app = FastAPI(title="Geo Assistant", lifespan=_lifespan)
app.add_middleware( app.add_middleware(
@@ -44,6 +46,7 @@ async def stream_chat(
chatbot: Any, chatbot: Any,
request: Request, request: Request,
) -> AsyncGenerator[bytes]: ) -> AsyncGenerator[bytes]:
"""Agent chat stream."""
config: dict[str, Any] = { config: dict[str, Any] = {
"configurable": { "configurable": {
"thread_id": str(thread_id), "thread_id": str(thread_id),
@@ -101,6 +104,7 @@ async def stream_chat(
@app.post("/chat") @app.post("/chat")
async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse: async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse:
"""HTTP POST endpoint at /chat."""
generator = stream_chat( generator = stream_chat(
ui_state_update=request.agent_state_input, ui_state_update=request.agent_state_input,
thread_id=request.thread_id, thread_id=request.thread_id,
+6
View File
@@ -1,13 +1,19 @@
"""Chat API schemas."""
from pydantic import BaseModel from pydantic import BaseModel
from geo_assistant.agent.state import GeoAssistantState from geo_assistant.agent.state import GeoAssistantState
class ChatRequestBody(BaseModel): class ChatRequestBody(BaseModel):
"""Schema for the request to the Chat API."""
thread_id: str thread_id: str
agent_state_input: GeoAssistantState agent_state_input: GeoAssistantState
class ChatResponse(BaseModel): class ChatResponse(BaseModel):
"""Schema for the response from the Chat API."""
thread_id: str thread_id: str
state: GeoAssistantState state: GeoAssistantState
+2
View File
@@ -1,3 +1,5 @@
"""Chat app frontend."""
import base64 import base64
import json import json
import os import os
+2
View File
@@ -1,3 +1,5 @@
"""List of tools available to the agent."""
from geo_assistant.tools.buffer import get_search_area from geo_assistant.tools.buffer import get_search_area
from geo_assistant.tools.naip import fetch_naip_img from geo_assistant.tools.naip import fetch_naip_img
from geo_assistant.tools.overture import get_place, get_places_within_buffer from geo_assistant.tools.overture import get_place, get_places_within_buffer
+9 -1
View File
@@ -1,3 +1,5 @@
"""Tool to create a buffer polygon around a geometry feature."""
from typing import Annotated from typing import Annotated
import geopandas as gpd import geopandas as gpd
@@ -17,8 +19,14 @@ async def get_search_area(
state: Annotated[GeoAssistantState, InjectedState], state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId] = "", tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command: ) -> Command:
"""Get a search area buffer in km around the place defined in the agent state.""" """
Get a search area buffer in km around the place defined in the agent state.
Args:
buffer_size_km: Radius of the buffer in kilometres.
state: Pass in 'place' as state into this agent.
tool_call_id: Optional ID for tracking the tool call.
"""
place_feature = state.get("place") place_feature = state.get("place")
if not place_feature: if not place_feature:
+4 -2
View File
@@ -1,4 +1,5 @@
# tools/naip_mpc_tools.py """Tool to query Planetary Computer STAC API for NAIP imagery."""
import base64 import base64
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from io import BytesIO from io import BytesIO
@@ -39,7 +40,8 @@ async def fetch_naip_img(
Args: Args:
start_date: Start date (YYYY-MM-DD). start_date: Start date (YYYY-MM-DD).
end_date: End date (YYYY-MM-DD). end_date: End date (YYYY-MM-DD).
state: Pass in 'search_area' as state into this agent.
tool_call_id: Optional ID for tracking the tool call.
""" """
if not state["search_area"]: if not state["search_area"]:
return Command( return Command(
+20 -5
View File
@@ -1,3 +1,5 @@
"""Tool to find closest matching Overture place based on user input."""
import json import json
import os import os
from typing import Annotated from typing import Annotated
@@ -21,7 +23,8 @@ load_dotenv()
def create_database_connection(): def create_database_connection():
"""Create and configure a DuckDB connection with necessary extensions. """
Create and configure a DuckDB connection with necessary extensions.
Args: Args:
database_path: Path to the DuckDB database file database_path: Path to the DuckDB database file
@@ -43,8 +46,14 @@ async def get_place(
place_name: str, place_name: str,
tool_call_id: Annotated[str, InjectedToolCallId] = "", tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command: ) -> Command:
"""Get place location from Overture Maps based on user input place name.""" """
Get place location from Overture Maps based on user input place name.
Args:
place_name: An address or location given as a human-readable string.
tool_call_id: Optional ID for tracking the tool call.
"""
db_connection = create_database_connection() db_connection = create_database_connection()
source = os.getenv("OVERTURE_SOURCE", "local") source = os.getenv("OVERTURE_SOURCE", "local")
if source == "s3": if source == "s3":
@@ -162,10 +171,16 @@ async def get_places_within_buffer(
state: Annotated[GeoAssistantState, InjectedState], state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId], tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command: ) -> Command:
"""Get places from Overture Maps within user specified area and user specified Overture place type. """
Get places from Overture Maps within user specified area and user specified Overture
Accepts: restaurant(s), cafe(s), coffee shop(s), bar(s), pub(s) - case insensitive.""" place type.
Args:
place: Overture place type. Accepts: restaurant(s), cafe(s), coffee shop(s),
bar(s), pub(s) - case insensitive.
state: Pass in 'search_area' as state into this agent.
tool_call_id: Optional ID for tracking the tool call.
"""
# Normalize the place type # Normalize the place type
place = normalize_place_type(place) place = normalize_place_type(place)
+11 -4
View File
@@ -17,7 +17,7 @@ dotenv.load_dotenv()
class SatImgSummary(dspy.Signature): class SatImgSummary(dspy.Signature):
"Describe things you see in the satellite image." """Describe things you see in the satellite image."""
img: dspy.Image = dspy.InputField(desc="A satellite image") img: dspy.Image = dspy.InputField(desc="A satellite image")
answer: str = dspy.OutputField(desc="Description of the image") answer: str = dspy.OutputField(desc="Description of the image")
@@ -33,7 +33,8 @@ class SatImgSummaryAgent(dspy.Module):
temperature: float = 0.5, temperature: float = 0.5,
max_tokens: int = 4_096, max_tokens: int = 4_096,
) -> None: ) -> None:
"""Initialize the satellite image summary agent. """
Initialize the satellite image summary agent.
Args: Args:
model: The Ollama model to use for summarization model: The Ollama model to use for summarization
@@ -53,7 +54,8 @@ class SatImgSummaryAgent(dspy.Module):
self.summarizer = dspy.Predict(SatImgSummary) self.summarizer = dspy.Predict(SatImgSummary)
def forward(self, img_url: str) -> dspy.Prediction: def forward(self, img_url: str) -> dspy.Prediction:
"""Generate a summary for the given image URL. """
Generate a summary for the given image URL.
Args: Args:
img_url: URL of the image to summarize img_url: URL of the image to summarize
@@ -73,7 +75,12 @@ async def summarize_sat_img(
state: Annotated[GeoAssistantState, InjectedState], state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str | None, InjectedToolCallId] = None, tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command: ) -> Command:
"""Summarize the contents of a satellite image using an LLM. """
Summarize the contents of a satellite image using an LLM.
Args:
state: Pass in 'naip_img_bytes' as state into this agent.
tool_call_id: Optional ID for tracking the tool call.
Returns: Returns:
Command containing the image summary and metadata Command containing the image summary and metadata
+4 -1
View File
@@ -1,3 +1,5 @@
"""Tests for chat API endpoint."""
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
@@ -10,7 +12,7 @@ from geo_assistant.api.app import app
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def initialized_app(): async def initialized_app():
"""Initialize the app's chatbot before testing""" """Initialize the app's chatbot before testing."""
# Manually initialize the chatbot as the lifespan would # Manually initialize the chatbot as the lifespan would
app.state.chatbot = await create_graph() app.state.chatbot = await create_graph()
yield app yield app
@@ -21,6 +23,7 @@ async def initialized_app():
@pytest.mark.xfail @pytest.mark.xfail
async def test_call_api(initialized_app): async def test_call_api(initialized_app):
"""Test calling the API at the /chat HTTP POST endpoint."""
async with AsyncClient( async with AsyncClient(
transport=ASGITransport(app=initialized_app), transport=ASGITransport(app=initialized_app),
base_url="http://test", base_url="http://test",
+4
View File
@@ -1,3 +1,5 @@
"""Tests for buffer tool."""
from geojson_pydantic import Feature, Point from geojson_pydantic import Feature, Point
from langchain_core.tools.base import ToolCall from langchain_core.tools.base import ToolCall
from pytest import fixture from pytest import fixture
@@ -8,6 +10,7 @@ from geo_assistant.tools.buffer import get_search_area
@fixture @fixture
def geo_assistant_fixture(): def geo_assistant_fixture():
"""Fixture with a GeoJSON point feature in a GeoAssistantState."""
place_geojson = Feature( place_geojson = Feature(
type="Feature", type="Feature",
geometry=Point(type="Point", coordinates=[-9.1393, 38.7223]), geometry=Point(type="Point", coordinates=[-9.1393, 38.7223]),
@@ -22,6 +25,7 @@ def geo_assistant_fixture():
async def test_get_search_area(geo_assistant_fixture): async def test_get_search_area(geo_assistant_fixture):
"""Ensure that `get_search_area` tool returns a buffer Polygon."""
# Call the underlying function directly to test the logic # Call the underlying function directly to test the logic
# This bypasses the injection framework which is better suited for integration tests # This bypasses the injection framework which is better suited for integration tests
command = await get_search_area.ainvoke( command = await get_search_area.ainvoke(
+2 -2
View File
@@ -1,3 +1,5 @@
"""Tests for NAIP tool."""
from types import NoneType from types import NoneType
import pytest import pytest
@@ -19,7 +21,6 @@ async def test_fetch_naip():
- Internet access (to reach Planetary Computer STAC + blobs) - Internet access (to reach Planetary Computer STAC + blobs)
- Planetary Computer / NAIP service to be up - Planetary Computer / NAIP service to be up
""" """
# Union Market coordinates from GeoNames: 38.90789, -76.99831 # Union Market coordinates from GeoNames: 38.90789, -76.99831
# N 38°54'28" W 76°59'54" # N 38°54'28" W 76°59'54"
# We'll use a small neighborhood AOI around that point. # We'll use a small neighborhood AOI around that point.
@@ -60,7 +61,6 @@ async def test_fetch_naip_too_large():
- Internet access (to reach Planetary Computer STAC + blobs) - Internet access (to reach Planetary Computer STAC + blobs)
- Planetary Computer / NAIP service to be up - Planetary Computer / NAIP service to be up
""" """
# Union Market coordinates from GeoNames: 38.90789, -76.99831 # Union Market coordinates from GeoNames: 38.90789, -76.99831
# N 38°54'28" W 76°59'54" # N 38°54'28" W 76°59'54"
# We'll use a small neighborhood AOI around that point. # We'll use a small neighborhood AOI around that point.
+7
View File
@@ -1,3 +1,5 @@
"""Tests for Overture tool."""
import os import os
import geopandas as gpd import geopandas as gpd
@@ -57,6 +59,7 @@ def geo_assistant_with_buffer_fixture():
async def test_get_place(): async def test_get_place():
"""Ensure that `get_place` tool returns an Overture place given a place_name."""
command = await get_place.ainvoke( command = await get_place.ainvoke(
ToolCall( ToolCall(
name="get_place", name="get_place",
@@ -69,6 +72,10 @@ async def test_get_place():
async def test_get_places_within_buffer(geo_assistant_with_buffer_fixture): async def test_get_places_within_buffer(geo_assistant_with_buffer_fixture):
"""
Ensure that `get_places_within_buffer` tool returns multiple Overture places that
match the category 'cafe' within a specific buffer area around a location.
"""
command = await get_places_within_buffer.ainvoke( command = await get_places_within_buffer.ainvoke(
ToolCall( ToolCall(
name="get_places_within_buffer", name="get_places_within_buffer",
+4
View File
@@ -23,6 +23,10 @@ TEST_IMAGE_URL = "https://petapixel.com/assets/uploads/2022/08/French-Officials-
], ],
) )
async def test_summarize_sat_img(img_url, summary): async def test_summarize_sat_img(img_url, summary):
"""
Ensure that the `summarize_sat_img` tool can describe a satellite image in JPEG
format.
"""
# Load the image from the supplied URL and encode it in base64 # Load the image from the supplied URL and encode it in base64
headers = { headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",