mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Use ruff rules COM, F, I, RUF, UP (#17)
* Use pyupgrade (UP) rule * Use pyflakes (F) rule * Use isort (I) rule * Use ruff-specific (RUF) rules * Use flake8-commas (COM) rule * Fix UP043 Unnecessary default type arguments
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import datetime
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langchain.agents import create_agent
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from geo_assistant.agent.llms import llm
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from geo_assistant.tools.buffer import get_search_area
|
||||
from geo_assistant.tools.naip import fetch_naip_img
|
||||
from geo_assistant.tools.overture import get_place
|
||||
from geo_assistant.tools.buffer import get_search_area
|
||||
from geo_assistant.tools.summarize import summarize_sat_img
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
@@ -36,7 +37,7 @@ async def create_graph():
|
||||
summarize_sat_img,
|
||||
],
|
||||
system_prompt=SYSTEM_PROMPT.format(
|
||||
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
),
|
||||
state_schema=GeoAssistantState,
|
||||
checkpointer=checkpointer,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
# Load environment variables from env file
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from langchain.agents import AgentState
|
||||
from typing import NotRequired
|
||||
|
||||
from geojson_pydantic import Feature
|
||||
from typing_extensions import NotRequired
|
||||
from langchain.agents import AgentState
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
@@ -8,5 +9,6 @@ class GeoAssistantState(AgentState):
|
||||
place: NotRequired[Feature | None] = None
|
||||
search_area: NotRequired[Feature | None] = None
|
||||
naip_png_path: NotRequired[str | None] = Field(
|
||||
default=None, description="Path to the saved NAIP RGB PNG image"
|
||||
default=None,
|
||||
description="Path to the saved NAIP RGB PNG image",
|
||||
)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import aclosing, asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
import logging
|
||||
from pydantic import UUID4
|
||||
|
||||
from geo_assistant.agent.graph import create_graph
|
||||
@@ -43,11 +44,11 @@ async def stream_chat(
|
||||
thread_id: UUID4,
|
||||
chatbot: Any,
|
||||
request: Request,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
config: Dict[str, Any] = {
|
||||
) -> AsyncGenerator[bytes]:
|
||||
config: dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": str(thread_id),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
state_updates = {}
|
||||
@@ -69,7 +70,7 @@ async def stream_chat(
|
||||
{
|
||||
"content": f"Manually selected data for field {key}: {description}",
|
||||
"type": "human",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Add UI messages to the existing messages if they exist
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@ import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import folium
|
||||
import httpx
|
||||
import streamlit as st
|
||||
import streamlit.components.v1 as components
|
||||
import folium
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
@@ -137,7 +137,10 @@ def stream_chat(user_message: str):
|
||||
# Fit map to bounds if we have coordinates
|
||||
if all_lons and all_lats:
|
||||
m.fit_bounds(
|
||||
[[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]]
|
||||
[
|
||||
[min(all_lats), min(all_lons)],
|
||||
[max(all_lats), max(all_lons)],
|
||||
],
|
||||
)
|
||||
|
||||
# Display the map
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from typing import Annotated
|
||||
|
||||
import geopandas as gpd
|
||||
from langgraph.types import Command
|
||||
from langgraph.prebuilt import InjectedState
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from geojson_pydantic import Feature
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.prebuilt import InjectedState
|
||||
from langgraph.types import Command
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from geojson_pydantic import Feature
|
||||
|
||||
|
||||
@tool
|
||||
@@ -26,9 +28,9 @@ async def get_search_area(
|
||||
ToolMessage(
|
||||
content="No place defined in the agent state to create a search area around.",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Convert GeoJSON feature to GeoDataFrame
|
||||
@@ -38,7 +40,7 @@ async def get_search_area(
|
||||
gdf_m = gdf.to_crs(epsg=3857) # latlon to Web Mercator for meter-based buffering
|
||||
|
||||
gdf_m["geometry"] = gdf_m["geometry"].buffer(
|
||||
buffer_size_km * 1000
|
||||
buffer_size_km * 1000,
|
||||
) # Buffer in meters
|
||||
gdf = gdf_m.to_crs(epsg=4326) # Back to WGS84
|
||||
|
||||
@@ -46,7 +48,7 @@ async def get_search_area(
|
||||
if len(gdf) != 1:
|
||||
raise ValueError(
|
||||
f"{len(gdf)} features found after buffer operation, should be just 1. "
|
||||
"Was a Multi-Point/LineString/Polygon geometry passed in?"
|
||||
"Was a Multi-Point/LineString/Polygon geometry passed in?",
|
||||
)
|
||||
buffer_feature = Feature(
|
||||
type="Feature",
|
||||
@@ -61,7 +63,7 @@ async def get_search_area(
|
||||
ToolMessage(
|
||||
content=f"Created search area geometry buffer of {buffer_size_km} km around the place.",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
# tools/naip_mpc_tools.py
|
||||
from typing import Dict, Any, Optional, Annotated
|
||||
from pathlib import Path
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
import matplotlib.pyplot as plt
|
||||
from langchain_core.tools import tool
|
||||
from pystac_client import Client
|
||||
from odc.stac import stac_load
|
||||
from langgraph.types import Command
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
import dotenv
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.types import Command
|
||||
from odc.stac import stac_load
|
||||
from pystac_client import Client
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@@ -23,10 +22,10 @@ E84_STAC_URL = "https://earth-search.aws.element84.com/v1"
|
||||
|
||||
@tool("fetch_naip_img")
|
||||
async def fetch_naip_img(
|
||||
aoi_geojson: Dict[str, Any],
|
||||
aoi_geojson: dict[str, Any],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
|
||||
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||
) -> Command:
|
||||
"""
|
||||
Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and
|
||||
@@ -56,10 +55,10 @@ async def fetch_naip_img(
|
||||
ToolMessage(
|
||||
content="No NAIP imagery found for the specified area and date range.",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
"naip_png_path": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# --- 2. Load as xarray cube with odc.stac ---
|
||||
@@ -81,14 +80,14 @@ async def fetch_naip_img(
|
||||
ToolMessage(
|
||||
content="Unable to load NAIP RGB image, dataset has no time dimension",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
"naip_png_path": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# --- 3. Build an RGB composite from the cube ---
|
||||
# For the PNG, we’ll just use the first time slice (you can swap in “latest”
|
||||
# For the PNG, we'll just use the first time slice (you can swap in “latest”
|
||||
# or a temporal reduction if you prefer).
|
||||
red = ds["Red"].isel(time=0)
|
||||
green = ds["Green"].isel(time=0)
|
||||
@@ -121,8 +120,8 @@ async def fetch_naip_img(
|
||||
ToolMessage(
|
||||
content=f"NAIP RGB image saved to {out_path.as_posix()}",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
"naip_png_path": out_path.as_posix(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -4,11 +4,11 @@ from typing import Annotated
|
||||
|
||||
import duckdb
|
||||
from dotenv import load_dotenv
|
||||
from geojson_pydantic import Feature
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.types import Command
|
||||
from geojson_pydantic import Feature
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
@@ -34,7 +34,8 @@ def create_database_connection():
|
||||
|
||||
@tool
|
||||
async def get_place(
|
||||
place_name: str, tool_call_id: Annotated[str, InjectedToolCallId] = ""
|
||||
place_name: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||
) -> Command:
|
||||
"""Get place location from Overture Maps based on user input place name."""
|
||||
|
||||
@@ -65,7 +66,7 @@ async def get_place(
|
||||
WHERE jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) > 0.5
|
||||
ORDER BY similarity_score DESC
|
||||
LIMIT 1;
|
||||
"""
|
||||
""",
|
||||
).fetchall()
|
||||
|
||||
db_connection.close()
|
||||
@@ -89,7 +90,7 @@ async def get_place(
|
||||
ToolMessage(
|
||||
content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}",
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
),
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""Tools for summarizing satellite images using LLM-based analysis."""
|
||||
|
||||
import os
|
||||
from typing import Annotated, Optional
|
||||
import dspy
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from typing import Annotated
|
||||
|
||||
import dotenv
|
||||
import dspy
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.types import Command
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@@ -68,7 +68,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
|
||||
@tool
|
||||
async def summarize_sat_img(
|
||||
img_url: str,
|
||||
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
|
||||
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||
) -> Command:
|
||||
"""Summarize the contents of a satellite image using an LLM.
|
||||
|
||||
@@ -96,7 +96,7 @@ async def summarize_sat_img(
|
||||
content=message_content,
|
||||
artifact=artifact,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
]
|
||||
}
|
||||
),
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user