Use ruff rules COM, F, I, RUF, UP (#17)

* Use pyupgrade (UP) rule

* Use pyflakes (F) rule

* Use isort (I) rule

* Use ruff-specific (RUF) rules

* Use flake8-commas (COM) rule

* Fix UP043 Unnecessary default type arguments
This commit is contained in:
Wei Ji
2025-12-05 10:44:37 +00:00
committed by GitHub
parent 7c97b475e4
commit e3530cefd2
16 changed files with 103 additions and 78 deletions
+9
View File
@@ -47,3 +47,12 @@ packages = ["src/geo_assistant"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "--color=yes" addopts = "--color=yes"
asyncio_mode = "auto" asyncio_mode = "auto"
[tool.ruff.lint]
select = [
"COM", # flake8-commas
"F", # pyflakes
"I", # isort
"RUF", # ruff-specific
"UP", # pyupgrade
]
+5 -4
View File
@@ -1,12 +1,13 @@
import datetime import datetime
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent from langchain.agents import create_agent
from geo_assistant.agent.state import GeoAssistantState from langgraph.checkpoint.memory import InMemorySaver
from geo_assistant.agent.llms import llm from geo_assistant.agent.llms import llm
from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools.buffer import get_search_area
from geo_assistant.tools.naip import fetch_naip_img from geo_assistant.tools.naip import fetch_naip_img
from geo_assistant.tools.overture import get_place from geo_assistant.tools.overture import get_place
from geo_assistant.tools.buffer import get_search_area
from geo_assistant.tools.summarize import summarize_sat_img from geo_assistant.tools.summarize import summarize_sat_img
SYSTEM_PROMPT = """ SYSTEM_PROMPT = """
@@ -36,7 +37,7 @@ async def create_graph():
summarize_sat_img, summarize_sat_img,
], ],
system_prompt=SYSTEM_PROMPT.format( system_prompt=SYSTEM_PROMPT.format(
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
), ),
state_schema=GeoAssistantState, state_schema=GeoAssistantState,
checkpointer=checkpointer, checkpointer=checkpointer,
+1 -1
View File
@@ -1,6 +1,6 @@
import os import os
from dotenv import load_dotenv
from dotenv import load_dotenv
from langchain_ollama import ChatOllama from langchain_ollama import ChatOllama
# Load environment variables from env file # Load environment variables from env file
+5 -3
View File
@@ -1,6 +1,7 @@
from langchain.agents import AgentState from typing import NotRequired
from geojson_pydantic import Feature from geojson_pydantic import Feature
from typing_extensions import NotRequired from langchain.agents import AgentState
from pydantic import Field from pydantic import Field
@@ -8,5 +9,6 @@ class GeoAssistantState(AgentState):
place: NotRequired[Feature | None] = None place: NotRequired[Feature | None] = None
search_area: NotRequired[Feature | None] = None search_area: NotRequired[Feature | None] = None
naip_png_path: NotRequired[str | None] = Field( naip_png_path: NotRequired[str | None] = Field(
default=None, description="Path to the saved NAIP RGB PNG image" default=None,
description="Path to the saved NAIP RGB PNG image",
) )
+7 -6
View File
@@ -1,11 +1,12 @@
import json import json
import logging
from collections.abc import AsyncGenerator
from contextlib import aclosing, asynccontextmanager from contextlib import aclosing, asynccontextmanager
from typing import Any, AsyncGenerator, Dict from typing import Any
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import logging
from pydantic import UUID4 from pydantic import UUID4
from geo_assistant.agent.graph import create_graph from geo_assistant.agent.graph import create_graph
@@ -43,11 +44,11 @@ async def stream_chat(
thread_id: UUID4, thread_id: UUID4,
chatbot: Any, chatbot: Any,
request: Request, request: Request,
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes]:
config: Dict[str, Any] = { config: dict[str, Any] = {
"configurable": { "configurable": {
"thread_id": str(thread_id), "thread_id": str(thread_id),
} },
} }
state_updates = {} state_updates = {}
@@ -69,7 +70,7 @@ async def stream_chat(
{ {
"content": f"Manually selected data for field {key}: {description}", "content": f"Manually selected data for field {key}: {description}",
"type": "human", "type": "human",
} },
) )
# Add UI messages to the existing messages if they exist # Add UI messages to the existing messages if they exist
+1
View File
@@ -1,4 +1,5 @@
from pydantic import BaseModel from pydantic import BaseModel
from geo_assistant.agent.state import GeoAssistantState from geo_assistant.agent.state import GeoAssistantState
+5 -2
View File
@@ -2,10 +2,10 @@ import json
import os import os
import uuid import uuid
import folium
import httpx import httpx
import streamlit as st import streamlit as st
import streamlit.components.v1 as components import streamlit.components.v1 as components
import folium
from dotenv import load_dotenv from dotenv import load_dotenv
# Load environment variables from .env file # Load environment variables from .env file
@@ -137,7 +137,10 @@ def stream_chat(user_message: str):
# Fit map to bounds if we have coordinates # Fit map to bounds if we have coordinates
if all_lons and all_lats: if all_lons and all_lats:
m.fit_bounds( m.fit_bounds(
[[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]] [
[min(all_lats), min(all_lons)],
[max(all_lats), max(all_lons)],
],
) )
# Display the map # Display the map
+13 -11
View File
@@ -1,12 +1,14 @@
from typing import Annotated
import geopandas as gpd import geopandas as gpd
from langgraph.types import Command from geojson_pydantic import Feature
from langgraph.prebuilt import InjectedState
from langchain_core.tools.base import InjectedToolCallId
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from typing import Annotated from langchain_core.tools.base import InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from geo_assistant.agent.state import GeoAssistantState from geo_assistant.agent.state import GeoAssistantState
from geojson_pydantic import Feature
@tool @tool
@@ -26,9 +28,9 @@ async def get_search_area(
ToolMessage( ToolMessage(
content="No place defined in the agent state to create a search area around.", content="No place defined in the agent state to create a search area around.",
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
) ),
], ],
} },
) )
# Convert GeoJSON feature to GeoDataFrame # Convert GeoJSON feature to GeoDataFrame
@@ -38,7 +40,7 @@ async def get_search_area(
gdf_m = gdf.to_crs(epsg=3857) # latlon to Web Mercator for meter-based buffering gdf_m = gdf.to_crs(epsg=3857) # latlon to Web Mercator for meter-based buffering
gdf_m["geometry"] = gdf_m["geometry"].buffer( gdf_m["geometry"] = gdf_m["geometry"].buffer(
buffer_size_km * 1000 buffer_size_km * 1000,
) # Buffer in meters ) # Buffer in meters
gdf = gdf_m.to_crs(epsg=4326) # Back to WGS84 gdf = gdf_m.to_crs(epsg=4326) # Back to WGS84
@@ -46,7 +48,7 @@ async def get_search_area(
if len(gdf) != 1: if len(gdf) != 1:
raise ValueError( raise ValueError(
f"{len(gdf)} features found after buffer operation, should be just 1. " f"{len(gdf)} features found after buffer operation, should be just 1. "
"Was a Multi-Point/LineString/Polygon geometry passed in?" "Was a Multi-Point/LineString/Polygon geometry passed in?",
) )
buffer_feature = Feature( buffer_feature = Feature(
type="Feature", type="Feature",
@@ -61,7 +63,7 @@ async def get_search_area(
ToolMessage( ToolMessage(
content=f"Created search area geometry buffer of {buffer_size_km} km around the place.", content=f"Created search area geometry buffer of {buffer_size_km} km around the place.",
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
) ),
], ],
} },
) )
+20 -21
View File
@@ -1,19 +1,18 @@
# tools/naip_mpc_tools.py # tools/naip_mpc_tools.py
from typing import Dict, Any, Optional, Annotated
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import numpy as np from pathlib import Path
import xarray as xr from typing import Annotated, Any
import matplotlib.pyplot as plt
from langchain_core.tools import tool
from pystac_client import Client
from odc.stac import stac_load
from langgraph.types import Command
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
import dotenv import dotenv
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.types import Command
from odc.stac import stac_load
from pystac_client import Client
dotenv.load_dotenv() dotenv.load_dotenv()
@@ -23,10 +22,10 @@ E84_STAC_URL = "https://earth-search.aws.element84.com/v1"
@tool("fetch_naip_img") @tool("fetch_naip_img")
async def fetch_naip_img( async def fetch_naip_img(
aoi_geojson: Dict[str, Any], aoi_geojson: dict[str, Any],
start_date: str, start_date: str,
end_date: str, end_date: str,
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None, tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command: ) -> Command:
""" """
Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and
@@ -56,10 +55,10 @@ async def fetch_naip_img(
ToolMessage( ToolMessage(
content="No NAIP imagery found for the specified area and date range.", content="No NAIP imagery found for the specified area and date range.",
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
) ),
], ],
"naip_png_path": None, "naip_png_path": None,
} },
) )
# --- 2. Load as xarray cube with odc.stac --- # --- 2. Load as xarray cube with odc.stac ---
@@ -81,14 +80,14 @@ async def fetch_naip_img(
ToolMessage( ToolMessage(
content="Unable to load NAIP RGB image, dataset has no time dimension", content="Unable to load NAIP RGB image, dataset has no time dimension",
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
) ),
], ],
"naip_png_path": None, "naip_png_path": None,
} },
) )
# --- 3. Build an RGB composite from the cube --- # --- 3. Build an RGB composite from the cube ---
# For the PNG, well just use the first time slice (you can swap in “latest” # For the PNG, we'll just use the first time slice (you can swap in “latest”
# or a temporal reduction if you prefer). # or a temporal reduction if you prefer).
red = ds["Red"].isel(time=0) red = ds["Red"].isel(time=0)
green = ds["Green"].isel(time=0) green = ds["Green"].isel(time=0)
@@ -121,8 +120,8 @@ async def fetch_naip_img(
ToolMessage( ToolMessage(
content=f"NAIP RGB image saved to {out_path.as_posix()}", content=f"NAIP RGB image saved to {out_path.as_posix()}",
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
) ),
], ],
"naip_png_path": out_path.as_posix(), "naip_png_path": out_path.as_posix(),
} },
) )
+5 -4
View File
@@ -4,11 +4,11 @@ from typing import Annotated
import duckdb import duckdb
from dotenv import load_dotenv from dotenv import load_dotenv
from geojson_pydantic import Feature
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId from langchain_core.tools.base import InjectedToolCallId
from langgraph.types import Command from langgraph.types import Command
from geojson_pydantic import Feature
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
@@ -34,7 +34,8 @@ def create_database_connection():
@tool @tool
async def get_place( async def get_place(
place_name: str, tool_call_id: Annotated[str, InjectedToolCallId] = "" place_name: str,
tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command: ) -> Command:
"""Get place location from Overture Maps based on user input place name.""" """Get place location from Overture Maps based on user input place name."""
@@ -65,7 +66,7 @@ async def get_place(
WHERE jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) > 0.5 WHERE jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) > 0.5
ORDER BY similarity_score DESC ORDER BY similarity_score DESC
LIMIT 1; LIMIT 1;
""" """,
).fetchall() ).fetchall()
db_connection.close() db_connection.close()
@@ -89,7 +90,7 @@ async def get_place(
ToolMessage( ToolMessage(
content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}", content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}",
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
) ),
], ],
}, },
) )
+10 -10
View File
@@ -1,14 +1,14 @@
"""Tools for summarizing satellite images using LLM-based analysis.""" """Tools for summarizing satellite images using LLM-based analysis."""
import os import os
from typing import Annotated, Optional from typing import Annotated
import dspy
from langchain_core.tools import tool
from langgraph.types import Command
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
import dotenv import dotenv
import dspy
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.types import Command
dotenv.load_dotenv() dotenv.load_dotenv()
@@ -68,7 +68,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
@tool @tool
async def summarize_sat_img( async def summarize_sat_img(
img_url: str, img_url: str,
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None, tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command: ) -> Command:
"""Summarize the contents of a satellite image using an LLM. """Summarize the contents of a satellite image using an LLM.
@@ -96,7 +96,7 @@ async def summarize_sat_img(
content=message_content, content=message_content,
artifact=artifact, artifact=artifact,
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
) ),
] ],
} },
) )
+8 -6
View File
@@ -1,10 +1,11 @@
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from uuid import uuid4 from uuid import uuid4
from geo_assistant.api.app import app import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from geo_assistant.agent.graph import create_graph from geo_assistant.agent.graph import create_graph
from geo_assistant.api.app import app
@pytest_asyncio.fixture @pytest_asyncio.fixture
@@ -22,7 +23,8 @@ async def initialized_app():
async def test_hello_world(initialized_app): async def test_hello_world(initialized_app):
"""Hello world test for the API""" """Hello world test for the API"""
async with AsyncClient( async with AsyncClient(
transport=ASGITransport(app=initialized_app), base_url="http://test" transport=ASGITransport(app=initialized_app),
base_url="http://test",
) as client: ) as client:
thread_id = uuid4() thread_id = uuid4()
response = await client.post( response = await client.post(
@@ -33,7 +35,7 @@ async def test_hello_world(initialized_app):
{ {
"content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it", "content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it",
"type": "human", "type": "human",
} },
], ],
"place": None, "place": None,
"search_area": None, "search_area": None,
+4 -3
View File
@@ -1,8 +1,9 @@
from pytest import fixture
from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools.buffer import get_search_area
from geojson_pydantic import Feature, Point from geojson_pydantic import Feature, Point
from langchain_core.tools.base import ToolCall from langchain_core.tools.base import ToolCall
from pytest import fixture
from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools.buffer import get_search_area
@fixture @fixture
+5 -4
View File
@@ -1,7 +1,8 @@
import pytest
from pathlib import Path from pathlib import Path
from shapely.geometry import box, mapping
import pytest
from langchain_core.tools.base import ToolCall from langchain_core.tools.base import ToolCall
from shapely.geometry import box, mapping
from geo_assistant.tools.naip import fetch_naip_img from geo_assistant.tools.naip import fetch_naip_img
@@ -19,7 +20,7 @@ async def test_fetch_naip(tmp_path):
""" """
# Union Market coordinates from GeoNames: 38.90789, -76.99831 # Union Market coordinates from GeoNames: 38.90789, -76.99831
# N 38°5428 W 76°5954 # N 38°54'28" W 76°59'54"
# We'll use a small neighborhood AOI around that point. # We'll use a small neighborhood AOI around that point.
# :contentReference[oaicite:0]{index=0} # :contentReference[oaicite:0]{index=0}
lat = 38.90789 lat = 38.90789
@@ -44,7 +45,7 @@ async def test_fetch_naip(tmp_path):
id="test_tool_call_id", id="test_tool_call_id",
) )
# Call the actual tool no STAC / odc-stac mocking # Call the actual tool - no STAC / odc-stac mocking
result = await fetch_naip_img.ainvoke(tool_call["args"]) result = await fetch_naip_img.ainvoke(tool_call["args"])
# Basic sanity checks on result # Basic sanity checks on result
+2 -1
View File
@@ -1,4 +1,5 @@
import os import os
import pytest import pytest
from langchain_core.tools.base import ToolCall from langchain_core.tools.base import ToolCall
@@ -24,6 +25,6 @@ async def test_get_place():
type="tool_call", type="tool_call",
id="test_id", id="test_id",
args={"place_name": "Neighbourhood Cafe Lisbon"}, args={"place_name": "Neighbourhood Cafe Lisbon"},
) ),
) )
assert "place" in command.update assert "place" in command.update
+3 -2
View File
@@ -1,9 +1,10 @@
"""Tests for the satellite image summarization tool.""" """Tests for the satellite image summarization tool."""
import pytest
import uuid import uuid
import pytest
from langchain_core.tools.base import ToolCall from langchain_core.tools.base import ToolCall
from geo_assistant.tools.summarize import summarize_sat_img from geo_assistant.tools.summarize import summarize_sat_img
# Sample test data # Sample test data
@@ -24,7 +25,7 @@ def test_summarize_sat_img(img_url, summary):
type="tool_call", type="tool_call",
args={"img_url": img_url}, args={"img_url": img_url},
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
) ),
) )
print(command.update.get("messages")) print(command.update.get("messages"))