mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-15 15:31:02 +02:00
Use ruff rules COM, F, I, RUF, UP (#17)
* Use pyupgrade (UP) rule * Use pyflakes (F) rule * Use isort (I) rule * Use ruff-specific (RUF) rules * Use flake8-commas (COM) rule * Fix UP043 Unnecessary default type arguments
This commit is contained in:
@@ -47,3 +47,12 @@ packages = ["src/geo_assistant"]
|
|||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "--color=yes"
|
addopts = "--color=yes"
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"COM", # flake8-commas
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"RUF", # ruff-specific
|
||||||
|
"UP", # pyupgrade
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from geo_assistant.agent.state import GeoAssistantState
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
from geo_assistant.agent.llms import llm
|
from geo_assistant.agent.llms import llm
|
||||||
|
from geo_assistant.agent.state import GeoAssistantState
|
||||||
|
from geo_assistant.tools.buffer import get_search_area
|
||||||
from geo_assistant.tools.naip import fetch_naip_img
|
from geo_assistant.tools.naip import fetch_naip_img
|
||||||
from geo_assistant.tools.overture import get_place
|
from geo_assistant.tools.overture import get_place
|
||||||
from geo_assistant.tools.buffer import get_search_area
|
|
||||||
from geo_assistant.tools.summarize import summarize_sat_img
|
from geo_assistant.tools.summarize import summarize_sat_img
|
||||||
|
|
||||||
SYSTEM_PROMPT = """
|
SYSTEM_PROMPT = """
|
||||||
@@ -36,7 +37,7 @@ async def create_graph():
|
|||||||
summarize_sat_img,
|
summarize_sat_img,
|
||||||
],
|
],
|
||||||
system_prompt=SYSTEM_PROMPT.format(
|
system_prompt=SYSTEM_PROMPT.format(
|
||||||
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
),
|
),
|
||||||
state_schema=GeoAssistantState,
|
state_schema=GeoAssistantState,
|
||||||
checkpointer=checkpointer,
|
checkpointer=checkpointer,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
from langchain_ollama import ChatOllama
|
from langchain_ollama import ChatOllama
|
||||||
|
|
||||||
# Load environment variables from env file
|
# Load environment variables from env file
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from langchain.agents import AgentState
|
from typing import NotRequired
|
||||||
|
|
||||||
from geojson_pydantic import Feature
|
from geojson_pydantic import Feature
|
||||||
from typing_extensions import NotRequired
|
from langchain.agents import AgentState
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
|
||||||
@@ -8,5 +9,6 @@ class GeoAssistantState(AgentState):
|
|||||||
place: NotRequired[Feature | None] = None
|
place: NotRequired[Feature | None] = None
|
||||||
search_area: NotRequired[Feature | None] = None
|
search_area: NotRequired[Feature | None] = None
|
||||||
naip_png_path: NotRequired[str | None] = Field(
|
naip_png_path: NotRequired[str | None] = Field(
|
||||||
default=None, description="Path to the saved NAIP RGB PNG image"
|
default=None,
|
||||||
|
description="Path to the saved NAIP RGB PNG image",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import aclosing, asynccontextmanager
|
from contextlib import aclosing, asynccontextmanager
|
||||||
from typing import Any, AsyncGenerator, Dict
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import logging
|
|
||||||
from pydantic import UUID4
|
from pydantic import UUID4
|
||||||
|
|
||||||
from geo_assistant.agent.graph import create_graph
|
from geo_assistant.agent.graph import create_graph
|
||||||
@@ -43,11 +44,11 @@ async def stream_chat(
|
|||||||
thread_id: UUID4,
|
thread_id: UUID4,
|
||||||
chatbot: Any,
|
chatbot: Any,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> AsyncGenerator[bytes, None]:
|
) -> AsyncGenerator[bytes]:
|
||||||
config: Dict[str, Any] = {
|
config: dict[str, Any] = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"thread_id": str(thread_id),
|
"thread_id": str(thread_id),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
state_updates = {}
|
state_updates = {}
|
||||||
@@ -69,7 +70,7 @@ async def stream_chat(
|
|||||||
{
|
{
|
||||||
"content": f"Manually selected data for field {key}: {description}",
|
"content": f"Manually selected data for field {key}: {description}",
|
||||||
"type": "human",
|
"type": "human",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add UI messages to the existing messages if they exist
|
# Add UI messages to the existing messages if they exist
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from geo_assistant.agent.state import GeoAssistantState
|
from geo_assistant.agent.state import GeoAssistantState
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ import json
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
import folium
|
||||||
import httpx
|
import httpx
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import streamlit.components.v1 as components
|
import streamlit.components.v1 as components
|
||||||
import folium
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Load environment variables from .env file
|
# Load environment variables from .env file
|
||||||
@@ -137,7 +137,10 @@ def stream_chat(user_message: str):
|
|||||||
# Fit map to bounds if we have coordinates
|
# Fit map to bounds if we have coordinates
|
||||||
if all_lons and all_lats:
|
if all_lons and all_lats:
|
||||||
m.fit_bounds(
|
m.fit_bounds(
|
||||||
[[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]]
|
[
|
||||||
|
[min(all_lats), min(all_lons)],
|
||||||
|
[max(all_lats), max(all_lons)],
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Display the map
|
# Display the map
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
import geopandas as gpd
|
import geopandas as gpd
|
||||||
from langgraph.types import Command
|
from geojson_pydantic import Feature
|
||||||
from langgraph.prebuilt import InjectedState
|
|
||||||
from langchain_core.tools.base import InjectedToolCallId
|
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from typing import Annotated
|
from langchain_core.tools.base import InjectedToolCallId
|
||||||
|
from langgraph.prebuilt import InjectedState
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
from geo_assistant.agent.state import GeoAssistantState
|
from geo_assistant.agent.state import GeoAssistantState
|
||||||
from geojson_pydantic import Feature
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -26,9 +28,9 @@ async def get_search_area(
|
|||||||
ToolMessage(
|
ToolMessage(
|
||||||
content="No place defined in the agent state to create a search area around.",
|
content="No place defined in the agent state to create a search area around.",
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert GeoJSON feature to GeoDataFrame
|
# Convert GeoJSON feature to GeoDataFrame
|
||||||
@@ -38,7 +40,7 @@ async def get_search_area(
|
|||||||
gdf_m = gdf.to_crs(epsg=3857) # latlon to Web Mercator for meter-based buffering
|
gdf_m = gdf.to_crs(epsg=3857) # latlon to Web Mercator for meter-based buffering
|
||||||
|
|
||||||
gdf_m["geometry"] = gdf_m["geometry"].buffer(
|
gdf_m["geometry"] = gdf_m["geometry"].buffer(
|
||||||
buffer_size_km * 1000
|
buffer_size_km * 1000,
|
||||||
) # Buffer in meters
|
) # Buffer in meters
|
||||||
gdf = gdf_m.to_crs(epsg=4326) # Back to WGS84
|
gdf = gdf_m.to_crs(epsg=4326) # Back to WGS84
|
||||||
|
|
||||||
@@ -46,7 +48,7 @@ async def get_search_area(
|
|||||||
if len(gdf) != 1:
|
if len(gdf) != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{len(gdf)} features found after buffer operation, should be just 1. "
|
f"{len(gdf)} features found after buffer operation, should be just 1. "
|
||||||
"Was a Multi-Point/LineString/Polygon geometry passed in?"
|
"Was a Multi-Point/LineString/Polygon geometry passed in?",
|
||||||
)
|
)
|
||||||
buffer_feature = Feature(
|
buffer_feature = Feature(
|
||||||
type="Feature",
|
type="Feature",
|
||||||
@@ -61,7 +63,7 @@ async def get_search_area(
|
|||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=f"Created search area geometry buffer of {buffer_size_km} km around the place.",
|
content=f"Created search area geometry buffer of {buffer_size_km} km around the place.",
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
# tools/naip_mpc_tools.py
|
# tools/naip_mpc_tools.py
|
||||||
from typing import Dict, Any, Optional, Annotated
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
import numpy as np
|
from pathlib import Path
|
||||||
import xarray as xr
|
from typing import Annotated, Any
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from pystac_client import Client
|
|
||||||
from odc.stac import stac_load
|
|
||||||
from langgraph.types import Command
|
|
||||||
from langchain_core.messages import ToolMessage
|
|
||||||
from langchain_core.tools.base import InjectedToolCallId
|
|
||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import xarray as xr
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_core.tools.base import InjectedToolCallId
|
||||||
|
from langgraph.types import Command
|
||||||
|
from odc.stac import stac_load
|
||||||
|
from pystac_client import Client
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@@ -23,10 +22,10 @@ E84_STAC_URL = "https://earth-search.aws.element84.com/v1"
|
|||||||
|
|
||||||
@tool("fetch_naip_img")
|
@tool("fetch_naip_img")
|
||||||
async def fetch_naip_img(
|
async def fetch_naip_img(
|
||||||
aoi_geojson: Dict[str, Any],
|
aoi_geojson: dict[str, Any],
|
||||||
start_date: str,
|
start_date: str,
|
||||||
end_date: str,
|
end_date: str,
|
||||||
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
|
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||||
) -> Command:
|
) -> Command:
|
||||||
"""
|
"""
|
||||||
Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and
|
Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and
|
||||||
@@ -56,10 +55,10 @@ async def fetch_naip_img(
|
|||||||
ToolMessage(
|
ToolMessage(
|
||||||
content="No NAIP imagery found for the specified area and date range.",
|
content="No NAIP imagery found for the specified area and date range.",
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
"naip_png_path": None,
|
"naip_png_path": None,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 2. Load as xarray cube with odc.stac ---
|
# --- 2. Load as xarray cube with odc.stac ---
|
||||||
@@ -81,14 +80,14 @@ async def fetch_naip_img(
|
|||||||
ToolMessage(
|
ToolMessage(
|
||||||
content="Unable to load NAIP RGB image, dataset has no time dimension",
|
content="Unable to load NAIP RGB image, dataset has no time dimension",
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
"naip_png_path": None,
|
"naip_png_path": None,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 3. Build an RGB composite from the cube ---
|
# --- 3. Build an RGB composite from the cube ---
|
||||||
# For the PNG, we’ll just use the first time slice (you can swap in “latest”
|
# For the PNG, we'll just use the first time slice (you can swap in “latest”
|
||||||
# or a temporal reduction if you prefer).
|
# or a temporal reduction if you prefer).
|
||||||
red = ds["Red"].isel(time=0)
|
red = ds["Red"].isel(time=0)
|
||||||
green = ds["Green"].isel(time=0)
|
green = ds["Green"].isel(time=0)
|
||||||
@@ -121,8 +120,8 @@ async def fetch_naip_img(
|
|||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=f"NAIP RGB image saved to {out_path.as_posix()}",
|
content=f"NAIP RGB image saved to {out_path.as_posix()}",
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
"naip_png_path": out_path.as_posix(),
|
"naip_png_path": out_path.as_posix(),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ from typing import Annotated
|
|||||||
|
|
||||||
import duckdb
|
import duckdb
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from geojson_pydantic import Feature
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langchain_core.tools.base import InjectedToolCallId
|
from langchain_core.tools.base import InjectedToolCallId
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
from geojson_pydantic import Feature
|
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -34,7 +34,8 @@ def create_database_connection():
|
|||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def get_place(
|
async def get_place(
|
||||||
place_name: str, tool_call_id: Annotated[str, InjectedToolCallId] = ""
|
place_name: str,
|
||||||
|
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||||
) -> Command:
|
) -> Command:
|
||||||
"""Get place location from Overture Maps based on user input place name."""
|
"""Get place location from Overture Maps based on user input place name."""
|
||||||
|
|
||||||
@@ -65,7 +66,7 @@ async def get_place(
|
|||||||
WHERE jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) > 0.5
|
WHERE jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) > 0.5
|
||||||
ORDER BY similarity_score DESC
|
ORDER BY similarity_score DESC
|
||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
"""
|
""",
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
db_connection.close()
|
db_connection.close()
|
||||||
@@ -89,7 +90,7 @@ async def get_place(
|
|||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}",
|
content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}",
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
"""Tools for summarizing satellite images using LLM-based analysis."""
|
"""Tools for summarizing satellite images using LLM-based analysis."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated
|
||||||
import dspy
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from langgraph.types import Command
|
|
||||||
from langchain_core.messages import ToolMessage
|
|
||||||
from langchain_core.tools.base import InjectedToolCallId
|
|
||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
|
import dspy
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_core.tools.base import InjectedToolCallId
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
|
|||||||
@tool
|
@tool
|
||||||
async def summarize_sat_img(
|
async def summarize_sat_img(
|
||||||
img_url: str,
|
img_url: str,
|
||||||
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
|
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||||
) -> Command:
|
) -> Command:
|
||||||
"""Summarize the contents of a satellite image using an LLM.
|
"""Summarize the contents of a satellite image using an LLM.
|
||||||
|
|
||||||
@@ -96,7 +96,7 @@ async def summarize_sat_img(
|
|||||||
content=message_content,
|
content=message_content,
|
||||||
artifact=artifact,
|
artifact=artifact,
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
)
|
),
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
+8
-6
@@ -1,10 +1,11 @@
|
|||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
from httpx import AsyncClient, ASGITransport
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from geo_assistant.api.app import app
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
from geo_assistant.agent.graph import create_graph
|
from geo_assistant.agent.graph import create_graph
|
||||||
|
from geo_assistant.api.app import app
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
@@ -22,7 +23,8 @@ async def initialized_app():
|
|||||||
async def test_hello_world(initialized_app):
|
async def test_hello_world(initialized_app):
|
||||||
"""Hello world test for the API"""
|
"""Hello world test for the API"""
|
||||||
async with AsyncClient(
|
async with AsyncClient(
|
||||||
transport=ASGITransport(app=initialized_app), base_url="http://test"
|
transport=ASGITransport(app=initialized_app),
|
||||||
|
base_url="http://test",
|
||||||
) as client:
|
) as client:
|
||||||
thread_id = uuid4()
|
thread_id = uuid4()
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@@ -33,7 +35,7 @@ async def test_hello_world(initialized_app):
|
|||||||
{
|
{
|
||||||
"content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it",
|
"content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it",
|
||||||
"type": "human",
|
"type": "human",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
"place": None,
|
"place": None,
|
||||||
"search_area": None,
|
"search_area": None,
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from pytest import fixture
|
|
||||||
from geo_assistant.agent.state import GeoAssistantState
|
|
||||||
from geo_assistant.tools.buffer import get_search_area
|
|
||||||
from geojson_pydantic import Feature, Point
|
from geojson_pydantic import Feature, Point
|
||||||
from langchain_core.tools.base import ToolCall
|
from langchain_core.tools.base import ToolCall
|
||||||
|
from pytest import fixture
|
||||||
|
|
||||||
|
from geo_assistant.agent.state import GeoAssistantState
|
||||||
|
from geo_assistant.tools.buffer import get_search_area
|
||||||
|
|
||||||
|
|
||||||
@fixture
|
@fixture
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import pytest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shapely.geometry import box, mapping
|
|
||||||
|
import pytest
|
||||||
from langchain_core.tools.base import ToolCall
|
from langchain_core.tools.base import ToolCall
|
||||||
|
from shapely.geometry import box, mapping
|
||||||
|
|
||||||
from geo_assistant.tools.naip import fetch_naip_img
|
from geo_assistant.tools.naip import fetch_naip_img
|
||||||
|
|
||||||
@@ -19,7 +20,7 @@ async def test_fetch_naip(tmp_path):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
||||||
# N 38°54′28″ W 76°59′54″
|
# N 38°54'28" W 76°59'54"
|
||||||
# We'll use a small neighborhood AOI around that point.
|
# We'll use a small neighborhood AOI around that point.
|
||||||
# :contentReference[oaicite:0]{index=0}
|
# :contentReference[oaicite:0]{index=0}
|
||||||
lat = 38.90789
|
lat = 38.90789
|
||||||
@@ -44,7 +45,7 @@ async def test_fetch_naip(tmp_path):
|
|||||||
id="test_tool_call_id",
|
id="test_tool_call_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the actual tool – no STAC / odc-stac mocking
|
# Call the actual tool - no STAC / odc-stac mocking
|
||||||
result = await fetch_naip_img.ainvoke(tool_call["args"])
|
result = await fetch_naip_img.ainvoke(tool_call["args"])
|
||||||
|
|
||||||
# Basic sanity checks on result
|
# Basic sanity checks on result
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.tools.base import ToolCall
|
from langchain_core.tools.base import ToolCall
|
||||||
|
|
||||||
@@ -24,6 +25,6 @@ async def test_get_place():
|
|||||||
type="tool_call",
|
type="tool_call",
|
||||||
id="test_id",
|
id="test_id",
|
||||||
args={"place_name": "Neighbourhood Cafe Lisbon"},
|
args={"place_name": "Neighbourhood Cafe Lisbon"},
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
assert "place" in command.update
|
assert "place" in command.update
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Tests for the satellite image summarization tool."""
|
"""Tests for the satellite image summarization tool."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
from langchain_core.tools.base import ToolCall
|
from langchain_core.tools.base import ToolCall
|
||||||
|
|
||||||
from geo_assistant.tools.summarize import summarize_sat_img
|
from geo_assistant.tools.summarize import summarize_sat_img
|
||||||
|
|
||||||
# Sample test data
|
# Sample test data
|
||||||
@@ -24,7 +25,7 @@ def test_summarize_sat_img(img_url, summary):
|
|||||||
type="tool_call",
|
type="tool_call",
|
||||||
args={"img_url": img_url},
|
args={"img_url": img_url},
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
print(command.update.get("messages"))
|
print(command.update.get("messages"))
|
||||||
|
|||||||
Reference in New Issue
Block a user