mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Render naip and summarize (#20)
* Intermediate * Fix NAIP geom handling * Fix imagery decoding for summary tool * Re-enable xfail * Re-enable xfail * Remove PNG references
This commit is contained in:
@@ -8,7 +8,7 @@ from pydantic import Field
|
||||
class GeoAssistantState(AgentState):
|
||||
place: NotRequired[Feature | None] = None
|
||||
search_area: NotRequired[Feature | None] = None
|
||||
naip_img_bytes: NotRequired[bytes | None] = Field(
|
||||
naip_img_bytes: NotRequired[str | None] = Field(
|
||||
default=None,
|
||||
description="Bytes of the saved NAIP RGB PNG image",
|
||||
description="Base 64 encoded bytes str of the saved NAIP RGB JPEG image",
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import aclosing, asynccontextmanager
|
||||
@@ -96,7 +95,7 @@ async def stream_chat(
|
||||
state = GeoAssistantState(**payload)
|
||||
resp = ChatResponse(thread_id=str(thread_id), state=state)
|
||||
|
||||
line = json.dumps(resp.model_dump()) + "\n"
|
||||
line = resp.model_dump_json() + "\n"
|
||||
yield line.encode("utf-8")
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
@@ -68,8 +69,16 @@ def stream_chat(user_message: str):
|
||||
for key, value in state.items():
|
||||
if value and isinstance(value, dict) and value.get("type") == "Feature":
|
||||
geojson_features[key] = value
|
||||
# with st.chat_message("tool"):
|
||||
# st.code(json.dumps(value, indent=2), language="json")
|
||||
elif value and isinstance(value, str) and key == "naip_img_bytes":
|
||||
# Handle base64-encoded jpeg data
|
||||
try:
|
||||
img_bytes = base64.b64decode(value)
|
||||
with st.chat_message("tool"):
|
||||
st.image(img_bytes)
|
||||
except Exception:
|
||||
# If decoding fails, fall through to JSON display
|
||||
with st.chat_message("tool"):
|
||||
st.code(json.dumps(value, indent=2), language="json")
|
||||
elif value:
|
||||
with st.chat_message("tool"):
|
||||
st.code(json.dumps(value, indent=2), language="json")
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# tools/naip_mpc_tools.py
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated
|
||||
|
||||
import dotenv
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -10,41 +11,54 @@ import xarray as xr
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.prebuilt import InjectedState
|
||||
from langgraph.types import Command
|
||||
from odc.stac import stac_load
|
||||
from pystac.extensions.raster import RasterBand
|
||||
from pystac_client import Client
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
DATA_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
|
||||
# DATA_URL = "https://earth-search.aws.element84.com/v1"
|
||||
|
||||
|
||||
@tool("fetch_naip_img")
|
||||
async def fetch_naip_img(
|
||||
aoi_geojson: dict[str, Any],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
state: Annotated[GeoAssistantState, InjectedState],
|
||||
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||
) -> Command:
|
||||
"""
|
||||
Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and
|
||||
date range, load all matching items into an xarray data cube using odc-stac,
|
||||
and save a simple RGB composite as a PNG.
|
||||
and save a simple RGB composite as a JPEG.
|
||||
|
||||
Args:
|
||||
aoi_geojson: GeoJSON Polygon/MultiPolygon in EPSG:4326.
|
||||
start_date: Start date (YYYY-MM-DD).
|
||||
end_date: End date (YYYY-MM-DD).
|
||||
|
||||
"""
|
||||
if not state["search_area"]:
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content="No search area avilable yetmee",
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
],
|
||||
"naip_img_bytes": None,
|
||||
},
|
||||
)
|
||||
# --- 1. STAC search on the Microsoft Planetary Computer STAC API (DATA_URL) ---
|
||||
catalog = Client.open(DATA_URL)
|
||||
|
||||
search = catalog.search(
|
||||
collections=["naip"],
|
||||
intersects=aoi_geojson,
|
||||
intersects=state["search_area"].geometry,
|
||||
datetime=f"{start_date}/{end_date}",
|
||||
)
|
||||
|
||||
@@ -68,19 +82,19 @@ async def fetch_naip_img(
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
],
|
||||
"naip_png_path": None,
|
||||
"naip_img_bytes": None,
|
||||
},
|
||||
)
|
||||
|
||||
# --- 2. Load as xarray cube with odc.stac ---
|
||||
# NAIP in MPC: 4-band multi-band asset (R,G,B,NIR) in one asset named "image".
|
||||
# odc.stac exposes these as measurements 'red','green','blue','nir' for this collection
|
||||
|
||||
# Limit to first item for now
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
ds: xr.Dataset = stac_load(
|
||||
items,
|
||||
items[:1],
|
||||
bands=["red", "green", "blue"], # use only RGB
|
||||
geopolygon=aoi_geojson,
|
||||
geopolygon=state["search_area"].geometry,
|
||||
resolution=1.0, # NAIP native ~1 m
|
||||
executor=executor,
|
||||
crs=items[0].properties["proj:code"],
|
||||
@@ -117,7 +131,7 @@ async def fetch_naip_img(
|
||||
)
|
||||
|
||||
# --- 3. Build an RGB composite from the cube ---
|
||||
# For the PNG, we'll just use the first time slice (you can swap in “latest”
|
||||
# For the JPEG, we'll just use the first time slice (you can swap in “latest”
|
||||
# or a temporal reduction if you prefer).
|
||||
red = ds["red"].isel(time=0)
|
||||
green = ds["green"].isel(time=0)
|
||||
@@ -127,7 +141,7 @@ async def fetch_naip_img(
|
||||
rgb = xr.concat([red, green, blue], dim="band") # (band, y, x)
|
||||
rgb = rgb.transpose("y", "x", "band") # (y, x, band)
|
||||
|
||||
# Convert to uint8 for PNG with a simple contrast stretch.
|
||||
# Convert to uint8 for JPEG with a simple contrast stretch.
|
||||
arr = rgb.values.astype("float32")
|
||||
# Robust min/max to avoid a few hot pixels blowing out the stretch
|
||||
vmin = np.nanpercentile(arr, 2)
|
||||
@@ -138,21 +152,21 @@ async def fetch_naip_img(
|
||||
arr = np.clip((arr - vmin) / (vmax - vmin + 1e-6), 0, 1)
|
||||
arr_uint8 = (arr * 255).astype("uint8")
|
||||
|
||||
# --- 4. Save PNG ---
|
||||
# --- 4. Save image ---
|
||||
|
||||
buf = BytesIO()
|
||||
plt.imsave(buf, arr_uint8, format="png")
|
||||
plt.imsave(buf, arr_uint8, format="jpeg")
|
||||
buf.seek(0)
|
||||
img_bytes = buf.getvalue()
|
||||
img_base64 = base64.b64encode(buf.read()).decode("utf-8")
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content="NAIP RGB image fetched and encoded as PNG bytes.",
|
||||
content="NAIP RGB image fetched and encoded as JPEG bytes.",
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
],
|
||||
"naip_img_bytes": img_bytes,
|
||||
"naip_img_bytes": img_base64,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -8,8 +8,11 @@ import dspy
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
from langgraph.prebuilt import InjectedState
|
||||
from langgraph.types import Command
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
@@ -67,7 +70,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
|
||||
|
||||
@tool
|
||||
async def summarize_sat_img(
|
||||
img_url: str,
|
||||
state: Annotated[GeoAssistantState, InjectedState],
|
||||
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
|
||||
) -> Command:
|
||||
"""Summarize the contents of a satellite image using an LLM.
|
||||
@@ -82,21 +85,24 @@ async def summarize_sat_img(
|
||||
Raises:
|
||||
ValueError: If the image URL is invalid or the image cannot be processed
|
||||
"""
|
||||
if not img_url or not isinstance(img_url, str):
|
||||
raise ValueError("img_url must be a non-empty string")
|
||||
|
||||
if not state["naip_img_bytes"]:
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content="No NAIP image bytes available yet",
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
],
|
||||
},
|
||||
)
|
||||
img_url = f"data:image/jpeg;base64,{state['naip_img_bytes']}"
|
||||
summary = _SUMMARIZER_AGENT(img_url)
|
||||
message_content = summary.answer
|
||||
artifact = {"img_url": img_url}
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=message_content,
|
||||
artifact=artifact,
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
ToolMessage(content=message_content, tool_call_id=tool_call_id),
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
+3
-3
@@ -20,8 +20,7 @@ async def initialized_app():
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
async def test_hello_world(initialized_app):
|
||||
"""Hello world test for the API"""
|
||||
async def test_call_api(initialized_app):
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=initialized_app),
|
||||
base_url="http://test",
|
||||
@@ -33,7 +32,7 @@ async def test_hello_world(initialized_app):
|
||||
"agent_state_input": {
|
||||
"messages": [
|
||||
{
|
||||
"content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it",
|
||||
"content": "Find The Whitney Hotel Boston and buffer 0.1km around it, then fetch the NAIP imagery for the area from 2021 and summarize the contents of the image.",
|
||||
"type": "human",
|
||||
},
|
||||
],
|
||||
@@ -43,6 +42,7 @@ async def test_hello_world(initialized_app):
|
||||
"thread_id": str(thread_id),
|
||||
},
|
||||
)
|
||||
print(response)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/x-ndjson; charset=utf-8"
|
||||
|
||||
@@ -17,7 +17,7 @@ def geo_assistant_fixture():
|
||||
place=place_geojson,
|
||||
search_area=None,
|
||||
messages=[],
|
||||
naip_png_path="path/to/naip.png",
|
||||
naip_img_bytes=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from types import NoneType
|
||||
|
||||
import pytest
|
||||
from geojson_pydantic import Feature
|
||||
from langchain_core.tools.base import ToolCall
|
||||
from shapely.geometry import box, mapping
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from geo_assistant.tools.naip import fetch_naip_img
|
||||
|
||||
|
||||
@@ -11,7 +13,7 @@ from geo_assistant.tools.naip import fetch_naip_img
|
||||
async def test_fetch_naip():
|
||||
"""
|
||||
Integration test: hit MPC STAC for NAIP around Union Market (DC),
|
||||
load imagery via odc-stac, and save an RGB PNG.
|
||||
load imagery via odc-stac, and save an RGB JPEG.
|
||||
|
||||
NOTE: This test requires:
|
||||
- Internet access (to reach Planetary Computer STAC + blobs)
|
||||
@@ -27,13 +29,13 @@ async def test_fetch_naip():
|
||||
# ~0.0001 degrees buffer in each direction
|
||||
aoi = box(lon - 0.0001, lat - 0.0001, lon + 0.0001, lat + 0.0001)
|
||||
aoi_geojson = mapping(aoi)
|
||||
|
||||
aoi_feature = Feature(type="Feature", geometry=aoi_geojson, properties={})
|
||||
tool_call = ToolCall(
|
||||
name="fetch_naip_img",
|
||||
args={
|
||||
"aoi_geojson": aoi_geojson,
|
||||
"start_date": "2021-01-01",
|
||||
"end_date": "2021-12-31",
|
||||
"state": GeoAssistantState(search_area=aoi_feature, messages=[]),
|
||||
},
|
||||
type="tool_call",
|
||||
id="test_tool_call_id",
|
||||
@@ -42,9 +44,9 @@ async def test_fetch_naip():
|
||||
# Call the actual tool - no STAC / odc-stac mocking
|
||||
result = await fetch_naip_img.ainvoke(tool_call)
|
||||
assert "naip_img_bytes" in result.update
|
||||
assert result.update["naip_img_bytes"] is not None, "Expected PNG bytes in result"
|
||||
assert isinstance(result.update["naip_img_bytes"], bytes)
|
||||
assert len(result.update["naip_img_bytes"]) > 1, "Expected non-empty PNG bytes"
|
||||
assert result.update["naip_img_bytes"] is not None, "Expected JPEG bytes in result"
|
||||
assert isinstance(result.update["naip_img_bytes"], str)
|
||||
assert len(result.update["naip_img_bytes"]) > 1, "Expected non-empty JPEG bytes"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -68,13 +70,13 @@ async def test_fetch_naip_too_large():
|
||||
# ~0.003 degrees buffer in each direction
|
||||
aoi = box(lon - 0.003, lat - 0.003, lon + 0.003, lat + 0.003)
|
||||
aoi_geojson = mapping(aoi)
|
||||
|
||||
aoi_feature = Feature(type="Feature", geometry=aoi_geojson, properties={})
|
||||
tool_call = ToolCall(
|
||||
name="fetch_naip_img",
|
||||
args={
|
||||
"aoi_geojson": aoi_geojson,
|
||||
"start_date": "2021-01-01",
|
||||
"end_date": "2021-12-31",
|
||||
"state": GeoAssistantState(search_area=aoi_feature, messages=[]),
|
||||
},
|
||||
type="tool_call",
|
||||
id="test_tool_call_id",
|
||||
@@ -83,5 +85,5 @@ async def test_fetch_naip_too_large():
|
||||
# Call the actual tool - no STAC / odc-stac mocking
|
||||
result = await fetch_naip_img.ainvoke(tool_call)
|
||||
assert "naip_img_bytes" in result.update
|
||||
assert result.update["naip_img_bytes"] is None, "Expected no PNG bytes in result"
|
||||
assert result.update["naip_img_bytes"] is None, "Expected no JPEG bytes in result"
|
||||
assert isinstance(result.update["naip_img_bytes"], NoneType)
|
||||
|
||||
@@ -1,29 +1,43 @@
|
||||
"""Tests for the satellite image summarization tool."""
|
||||
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from langchain_core.tools.base import ToolCall
|
||||
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from geo_assistant.tools.summarize import summarize_sat_img
|
||||
|
||||
# Sample test data
|
||||
TEST_IMAGE_URL = "https://petapixel.com/assets/uploads/2022/08/French-Officials-Use-Satellite-Photos-and-AI-to-Spot-Unregistered-Pools-1536x806.jpg"
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"img_url,summary",
|
||||
[
|
||||
(TEST_IMAGE_URL, "building"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.xfail
|
||||
def test_summarize_sat_img(img_url, summary):
|
||||
command = summarize_sat_img.invoke(
|
||||
async def test_summarize_sat_img(img_url, summary):
|
||||
# Load the image from the supplied URL and encode it in base64
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
||||
}
|
||||
resp = requests.get(img_url, headers=headers)
|
||||
resp.raise_for_status()
|
||||
img_base64 = base64.b64encode(resp.content).decode("utf-8")
|
||||
command = await summarize_sat_img.ainvoke(
|
||||
ToolCall(
|
||||
name="summarize_sat_img",
|
||||
type="tool_call",
|
||||
args={"img_url": img_url},
|
||||
args={
|
||||
"state": GeoAssistantState(naip_img_bytes=img_base64, messages=[]),
|
||||
"tool_call_id": str(uuid.uuid4()),
|
||||
},
|
||||
id=str(uuid.uuid4()),
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user