Use ruff rules COM, F, I, RUF, UP (#17)

* Use pyupgrade (UP) rule

* Use pyflakes (F) rule

* Use isort (I) rule

* Use ruff-specific (RUF) rules

* Use flake8-commas (COM) rule

* Fix UP043 Unnecessary default type arguments
This commit is contained in:
Wei Ji
2025-12-05 10:44:37 +00:00
committed by GitHub
parent 7c97b475e4
commit e3530cefd2
16 changed files with 103 additions and 78 deletions
+9
View File
@@ -47,3 +47,12 @@ packages = ["src/geo_assistant"]
[tool.pytest.ini_options]
addopts = "--color=yes"
asyncio_mode = "auto"
[tool.ruff.lint]
select = [
"COM", # flake8-commas
"F", # pyflakes
"I", # isort
"RUF", # ruff-specific
"UP", # pyupgrade
]
+5 -4
View File
@@ -1,12 +1,13 @@
import datetime
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent
from geo_assistant.agent.state import GeoAssistantState
from langgraph.checkpoint.memory import InMemorySaver
from geo_assistant.agent.llms import llm
from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools.buffer import get_search_area
from geo_assistant.tools.naip import fetch_naip_img
from geo_assistant.tools.overture import get_place
from geo_assistant.tools.buffer import get_search_area
from geo_assistant.tools.summarize import summarize_sat_img
SYSTEM_PROMPT = """
@@ -36,7 +37,7 @@ async def create_graph():
summarize_sat_img,
],
system_prompt=SYSTEM_PROMPT.format(
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
),
state_schema=GeoAssistantState,
checkpointer=checkpointer,
+1 -1
View File
@@ -1,6 +1,6 @@
import os
from dotenv import load_dotenv
from dotenv import load_dotenv
from langchain_ollama import ChatOllama
# Load environment variables from env file
+5 -3
View File
@@ -1,6 +1,7 @@
from langchain.agents import AgentState
from typing import NotRequired
from geojson_pydantic import Feature
from typing_extensions import NotRequired
from langchain.agents import AgentState
from pydantic import Field
@@ -8,5 +9,6 @@ class GeoAssistantState(AgentState):
place: NotRequired[Feature | None] = None
search_area: NotRequired[Feature | None] = None
naip_png_path: NotRequired[str | None] = Field(
default=None, description="Path to the saved NAIP RGB PNG image"
default=None,
description="Path to the saved NAIP RGB PNG image",
)
+7 -6
View File
@@ -1,11 +1,12 @@
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import aclosing, asynccontextmanager
from typing import Any, AsyncGenerator, Dict
from typing import Any
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import logging
from pydantic import UUID4
from geo_assistant.agent.graph import create_graph
@@ -43,11 +44,11 @@ async def stream_chat(
thread_id: UUID4,
chatbot: Any,
request: Request,
) -> AsyncGenerator[bytes, None]:
config: Dict[str, Any] = {
) -> AsyncGenerator[bytes]:
config: dict[str, Any] = {
"configurable": {
"thread_id": str(thread_id),
}
},
}
state_updates = {}
@@ -69,7 +70,7 @@ async def stream_chat(
{
"content": f"Manually selected data for field {key}: {description}",
"type": "human",
}
},
)
# Add UI messages to the existing messages if they exist
+1
View File
@@ -1,4 +1,5 @@
from pydantic import BaseModel
from geo_assistant.agent.state import GeoAssistantState
+5 -2
View File
@@ -2,10 +2,10 @@ import json
import os
import uuid
import folium
import httpx
import streamlit as st
import streamlit.components.v1 as components
import folium
from dotenv import load_dotenv
# Load environment variables from .env file
@@ -137,7 +137,10 @@ def stream_chat(user_message: str):
# Fit map to bounds if we have coordinates
if all_lons and all_lats:
m.fit_bounds(
[[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]]
[
[min(all_lats), min(all_lons)],
[max(all_lats), max(all_lons)],
],
)
# Display the map
+13 -11
View File
@@ -1,12 +1,14 @@
from typing import Annotated
import geopandas as gpd
from langgraph.types import Command
from langgraph.prebuilt import InjectedState
from langchain_core.tools.base import InjectedToolCallId
from geojson_pydantic import Feature
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from typing import Annotated
from langchain_core.tools.base import InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from geo_assistant.agent.state import GeoAssistantState
from geojson_pydantic import Feature
@tool
@@ -26,9 +28,9 @@ async def get_search_area(
ToolMessage(
content="No place defined in the agent state to create a search area around.",
tool_call_id=tool_call_id,
)
),
],
}
},
)
# Convert GeoJSON feature to GeoDataFrame
@@ -38,7 +40,7 @@ async def get_search_area(
gdf_m = gdf.to_crs(epsg=3857) # latlon to Web Mercator for meter-based buffering
gdf_m["geometry"] = gdf_m["geometry"].buffer(
buffer_size_km * 1000
buffer_size_km * 1000,
) # Buffer in meters
gdf = gdf_m.to_crs(epsg=4326) # Back to WGS84
@@ -46,7 +48,7 @@ async def get_search_area(
if len(gdf) != 1:
raise ValueError(
f"{len(gdf)} features found after buffer operation, should be just 1. "
"Was a Multi-Point/LineString/Polygon geometry passed in?"
"Was a Multi-Point/LineString/Polygon geometry passed in?",
)
buffer_feature = Feature(
type="Feature",
@@ -61,7 +63,7 @@ async def get_search_area(
ToolMessage(
content=f"Created search area geometry buffer of {buffer_size_km} km around the place.",
tool_call_id=tool_call_id,
)
),
],
}
},
)
+20 -21
View File
@@ -1,19 +1,18 @@
# tools/naip_mpc_tools.py
from typing import Dict, Any, Optional, Annotated
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from langchain_core.tools import tool
from pystac_client import Client
from odc.stac import stac_load
from langgraph.types import Command
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
from pathlib import Path
from typing import Annotated, Any
import dotenv
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.types import Command
from odc.stac import stac_load
from pystac_client import Client
dotenv.load_dotenv()
@@ -23,10 +22,10 @@ E84_STAC_URL = "https://earth-search.aws.element84.com/v1"
@tool("fetch_naip_img")
async def fetch_naip_img(
aoi_geojson: Dict[str, Any],
aoi_geojson: dict[str, Any],
start_date: str,
end_date: str,
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command:
"""
Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and
@@ -56,10 +55,10 @@ async def fetch_naip_img(
ToolMessage(
content="No NAIP imagery found for the specified area and date range.",
tool_call_id=tool_call_id,
)
),
],
"naip_png_path": None,
}
},
)
# --- 2. Load as xarray cube with odc.stac ---
@@ -81,14 +80,14 @@ async def fetch_naip_img(
ToolMessage(
content="Unable to load NAIP RGB image, dataset has no time dimension",
tool_call_id=tool_call_id,
)
),
],
"naip_png_path": None,
}
},
)
# --- 3. Build an RGB composite from the cube ---
# For the PNG, well just use the first time slice (you can swap in “latest”
# For the PNG, we'll just use the first time slice (you can swap in “latest”
# or a temporal reduction if you prefer).
red = ds["Red"].isel(time=0)
green = ds["Green"].isel(time=0)
@@ -121,8 +120,8 @@ async def fetch_naip_img(
ToolMessage(
content=f"NAIP RGB image saved to {out_path.as_posix()}",
tool_call_id=tool_call_id,
)
),
],
"naip_png_path": out_path.as_posix(),
}
},
)
+5 -4
View File
@@ -4,11 +4,11 @@ from typing import Annotated
import duckdb
from dotenv import load_dotenv
from geojson_pydantic import Feature
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.types import Command
from geojson_pydantic import Feature
# Load environment variables
load_dotenv()
@@ -34,7 +34,8 @@ def create_database_connection():
@tool
async def get_place(
place_name: str, tool_call_id: Annotated[str, InjectedToolCallId] = ""
place_name: str,
tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command:
"""Get place location from Overture Maps based on user input place name."""
@@ -65,7 +66,7 @@ async def get_place(
WHERE jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) > 0.5
ORDER BY similarity_score DESC
LIMIT 1;
"""
""",
).fetchall()
db_connection.close()
@@ -89,7 +90,7 @@ async def get_place(
ToolMessage(
content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}",
tool_call_id=tool_call_id,
)
),
],
},
)
+10 -10
View File
@@ -1,14 +1,14 @@
"""Tools for summarizing satellite images using LLM-based analysis."""
import os
from typing import Annotated, Optional
import dspy
from langchain_core.tools import tool
from langgraph.types import Command
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
from typing import Annotated
import dotenv
import dspy
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.types import Command
dotenv.load_dotenv()
@@ -68,7 +68,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
@tool
async def summarize_sat_img(
img_url: str,
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command:
"""Summarize the contents of a satellite image using an LLM.
@@ -96,7 +96,7 @@ async def summarize_sat_img(
content=message_content,
artifact=artifact,
tool_call_id=tool_call_id,
)
]
}
),
],
},
)
+8 -6
View File
@@ -1,10 +1,11 @@
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from uuid import uuid4
from geo_assistant.api.app import app
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from geo_assistant.agent.graph import create_graph
from geo_assistant.api.app import app
@pytest_asyncio.fixture
@@ -22,7 +23,8 @@ async def initialized_app():
async def test_hello_world(initialized_app):
"""Hello world test for the API"""
async with AsyncClient(
transport=ASGITransport(app=initialized_app), base_url="http://test"
transport=ASGITransport(app=initialized_app),
base_url="http://test",
) as client:
thread_id = uuid4()
response = await client.post(
@@ -33,7 +35,7 @@ async def test_hello_world(initialized_app):
{
"content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it",
"type": "human",
}
},
],
"place": None,
"search_area": None,
+4 -3
View File
@@ -1,8 +1,9 @@
from pytest import fixture
from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools.buffer import get_search_area
from geojson_pydantic import Feature, Point
from langchain_core.tools.base import ToolCall
from pytest import fixture
from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools.buffer import get_search_area
@fixture
+5 -4
View File
@@ -1,7 +1,8 @@
import pytest
from pathlib import Path
from shapely.geometry import box, mapping
import pytest
from langchain_core.tools.base import ToolCall
from shapely.geometry import box, mapping
from geo_assistant.tools.naip import fetch_naip_img
@@ -19,7 +20,7 @@ async def test_fetch_naip(tmp_path):
"""
# Union Market coordinates from GeoNames: 38.90789, -76.99831
# N 38°5428 W 76°5954
# N 38°54'28" W 76°59'54"
# We'll use a small neighborhood AOI around that point.
# :contentReference[oaicite:0]{index=0}
lat = 38.90789
@@ -44,7 +45,7 @@ async def test_fetch_naip(tmp_path):
id="test_tool_call_id",
)
# Call the actual tool no STAC / odc-stac mocking
# Call the actual tool - no STAC / odc-stac mocking
result = await fetch_naip_img.ainvoke(tool_call["args"])
# Basic sanity checks on result
+2 -1
View File
@@ -1,4 +1,5 @@
import os
import pytest
from langchain_core.tools.base import ToolCall
@@ -24,6 +25,6 @@ async def test_get_place():
type="tool_call",
id="test_id",
args={"place_name": "Neighbourhood Cafe Lisbon"},
)
),
)
assert "place" in command.update
+3 -2
View File
@@ -1,9 +1,10 @@
"""Tests for the satellite image summarization tool."""
import pytest
import uuid
import pytest
from langchain_core.tools.base import ToolCall
from geo_assistant.tools.summarize import summarize_sat_img
# Sample test data
@@ -24,7 +25,7 @@ def test_summarize_sat_img(img_url, summary):
type="tool_call",
args={"img_url": img_url},
id=str(uuid.uuid4()),
)
),
)
print(command.update.get("messages"))