Render naip and summarize (#20)

* Intermediate

* Fix naip geom handling

* Fix imagery decoding for summary tool

* Re enable xfail

* Re enable xfail

* Remove png references
This commit is contained in:
Daniel Wiesmann
2025-12-05 15:39:16 +00:00
committed by GitHub
parent dddac818ea
commit 9b202c504e
9 changed files with 95 additions and 51 deletions
+2 -2
View File
@@ -8,7 +8,7 @@ from pydantic import Field
class GeoAssistantState(AgentState): class GeoAssistantState(AgentState):
place: NotRequired[Feature | None] = None place: NotRequired[Feature | None] = None
search_area: NotRequired[Feature | None] = None search_area: NotRequired[Feature | None] = None
naip_img_bytes: NotRequired[bytes | None] = Field( naip_img_bytes: NotRequired[str | None] = Field(
default=None, default=None,
description="Bytes of the saved NAIP RGB PNG image", description="Base 64 encoded bytes str of the saved NAIP RGB JPEG image",
) )
+1 -2
View File
@@ -1,4 +1,3 @@
import json
import logging import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import aclosing, asynccontextmanager from contextlib import aclosing, asynccontextmanager
@@ -96,7 +95,7 @@ async def stream_chat(
state = GeoAssistantState(**payload) state = GeoAssistantState(**payload)
resp = ChatResponse(thread_id=str(thread_id), state=state) resp = ChatResponse(thread_id=str(thread_id), state=state)
line = json.dumps(resp.model_dump()) + "\n" line = resp.model_dump_json() + "\n"
yield line.encode("utf-8") yield line.encode("utf-8")
+11 -2
View File
@@ -1,3 +1,4 @@
import base64
import json import json
import os import os
import uuid import uuid
@@ -68,8 +69,16 @@ def stream_chat(user_message: str):
for key, value in state.items(): for key, value in state.items():
if value and isinstance(value, dict) and value.get("type") == "Feature": if value and isinstance(value, dict) and value.get("type") == "Feature":
geojson_features[key] = value geojson_features[key] = value
# with st.chat_message("tool"): elif value and isinstance(value, str) and key == "naip_img_bytes":
# st.code(json.dumps(value, indent=2), language="json") # Handle base64-encoded jpeg data
try:
img_bytes = base64.b64decode(value)
with st.chat_message("tool"):
st.image(img_bytes)
except Exception:
# If decoding fails, fall through to JSON display
with st.chat_message("tool"):
st.code(json.dumps(value, indent=2), language="json")
elif value: elif value:
with st.chat_message("tool"): with st.chat_message("tool"):
st.code(json.dumps(value, indent=2), language="json") st.code(json.dumps(value, indent=2), language="json")
+31 -17
View File
@@ -1,7 +1,8 @@
# tools/naip_mpc_tools.py # tools/naip_mpc_tools.py
import base64
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from io import BytesIO from io import BytesIO
from typing import Annotated, Any from typing import Annotated
import dotenv import dotenv
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -10,41 +11,54 @@ import xarray as xr
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId from langchain_core.tools.base import InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command from langgraph.types import Command
from odc.stac import stac_load from odc.stac import stac_load
from pystac.extensions.raster import RasterBand from pystac.extensions.raster import RasterBand
from pystac_client import Client from pystac_client import Client
from geo_assistant.agent.state import GeoAssistantState
dotenv.load_dotenv() dotenv.load_dotenv()
DATA_URL = "https://planetarycomputer.microsoft.com/api/stac/v1" DATA_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
# DATA_URL = "https://earth-search.aws.element84.com/v1"
@tool("fetch_naip_img") @tool("fetch_naip_img")
async def fetch_naip_img( async def fetch_naip_img(
aoi_geojson: dict[str, Any],
start_date: str, start_date: str,
end_date: str, end_date: str,
state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str | None, InjectedToolCallId] = None, tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command: ) -> Command:
""" """
Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and
date range, load all matching items into an xarray data cube using odc-stac, date range, load all matching items into an xarray data cube using odc-stac,
and save a simple RGB composite as a PNG. and save a simple RGB composite as a JPEG.
Args: Args:
aoi_geojson: GeoJSON Polygon/MultiPolygon in EPSG:4326.
start_date: Start date (YYYY-MM-DD). start_date: Start date (YYYY-MM-DD).
end_date: End date (YYYY-MM-DD). end_date: End date (YYYY-MM-DD).
""" """
if not state["search_area"]:
return Command(
update={
"messages": [
ToolMessage(
content="No search area avilable yetmee",
tool_call_id=tool_call_id,
),
],
"naip_img_bytes": None,
},
)
# --- 1. STAC search on the Microsoft Planetary Computer STAC API --- # --- 1. STAC search on the Microsoft Planetary Computer STAC API ---
catalog = Client.open(DATA_URL) catalog = Client.open(DATA_URL)
search = catalog.search( search = catalog.search(
collections=["naip"], collections=["naip"],
intersects=aoi_geojson, intersects=state["search_area"].geometry,
datetime=f"{start_date}/{end_date}", datetime=f"{start_date}/{end_date}",
) )
@@ -68,19 +82,19 @@ async def fetch_naip_img(
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
), ),
], ],
"naip_png_path": None, "naip_img_bytes": None,
}, },
) )
# --- 2. Load as xarray cube with odc.stac --- # --- 2. Load as xarray cube with odc.stac ---
# NAIP in MPC: 4-band multi-band asset (R,G,B,NIR) in one asset named "image". # NAIP in MPC: 4-band multi-band asset (R,G,B,NIR) in one asset named "image".
# odc.stac exposes these as measurements 'red','green','blue','nir' for this collection # odc.stac exposes these as measurements 'red','green','blue','nir' for this collection
# Limit to first item for now
with ThreadPoolExecutor(max_workers=5) as executor: with ThreadPoolExecutor(max_workers=5) as executor:
ds: xr.Dataset = stac_load( ds: xr.Dataset = stac_load(
items, items[:1],
bands=["red", "green", "blue"], # use only RGB bands=["red", "green", "blue"], # use only RGB
geopolygon=aoi_geojson, geopolygon=state["search_area"].geometry,
resolution=1.0, # NAIP native ~1 m resolution=1.0, # NAIP native ~1 m
executor=executor, executor=executor,
crs=items[0].properties["proj:code"], crs=items[0].properties["proj:code"],
@@ -117,7 +131,7 @@ async def fetch_naip_img(
) )
# --- 3. Build an RGB composite from the cube --- # --- 3. Build an RGB composite from the cube ---
# For the PNG, we'll just use the first time slice (you can swap in “latest” # For the JPEG, we'll just use the first time slice (you can swap in “latest”
# or a temporal reduction if you prefer). # or a temporal reduction if you prefer).
red = ds["red"].isel(time=0) red = ds["red"].isel(time=0)
green = ds["green"].isel(time=0) green = ds["green"].isel(time=0)
@@ -127,7 +141,7 @@ async def fetch_naip_img(
rgb = xr.concat([red, green, blue], dim="band") # (band, y, x) rgb = xr.concat([red, green, blue], dim="band") # (band, y, x)
rgb = rgb.transpose("y", "x", "band") # (y, x, band) rgb = rgb.transpose("y", "x", "band") # (y, x, band)
# Convert to uint8 for PNG with a simple contrast stretch. # Convert to uint8 for JPEG with a simple contrast stretch.
arr = rgb.values.astype("float32") arr = rgb.values.astype("float32")
# Robust min/max to avoid a few hot pixels blowing out the stretch # Robust min/max to avoid a few hot pixels blowing out the stretch
vmin = np.nanpercentile(arr, 2) vmin = np.nanpercentile(arr, 2)
@@ -138,21 +152,21 @@ async def fetch_naip_img(
arr = np.clip((arr - vmin) / (vmax - vmin + 1e-6), 0, 1) arr = np.clip((arr - vmin) / (vmax - vmin + 1e-6), 0, 1)
arr_uint8 = (arr * 255).astype("uint8") arr_uint8 = (arr * 255).astype("uint8")
# --- 4. Save PNG --- # --- 4. Save image ---
buf = BytesIO() buf = BytesIO()
plt.imsave(buf, arr_uint8, format="png") plt.imsave(buf, arr_uint8, format="jpeg")
buf.seek(0) buf.seek(0)
img_bytes = buf.getvalue() img_base64 = base64.b64encode(buf.read()).decode("utf-8")
return Command( return Command(
update={ update={
"messages": [ "messages": [
ToolMessage( ToolMessage(
content="NAIP RGB image fetched and encoded as PNG bytes.", content="NAIP RGB image fetched and encoded as JPEG bytes.",
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
), ),
], ],
"naip_img_bytes": img_bytes, "naip_img_bytes": img_base64,
}, },
) )
+16 -10
View File
@@ -8,8 +8,11 @@ import dspy
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId from langchain_core.tools.base import InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command from langgraph.types import Command
from geo_assistant.agent.state import GeoAssistantState
dotenv.load_dotenv() dotenv.load_dotenv()
@@ -67,7 +70,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
@tool @tool
async def summarize_sat_img( async def summarize_sat_img(
img_url: str, state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str | None, InjectedToolCallId] = None, tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command: ) -> Command:
"""Summarize the contents of a satellite image using an LLM. """Summarize the contents of a satellite image using an LLM.
@@ -82,21 +85,24 @@ async def summarize_sat_img(
Raises: Raises:
ValueError: If the image URL is invalid or the image cannot be processed ValueError: If the image URL is invalid or the image cannot be processed
""" """
if not img_url or not isinstance(img_url, str): if not state["naip_img_bytes"]:
raise ValueError("img_url must be a non-empty string")
summary = _SUMMARIZER_AGENT(img_url)
message_content = summary.answer
artifact = {"img_url": img_url}
return Command( return Command(
update={ update={
"messages": [ "messages": [
ToolMessage( ToolMessage(
content=message_content, content="No NAIP image bytes available yet",
artifact=artifact,
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
), ),
], ],
}, },
) )
img_url = f"data:image/jpeg;base64,{state['naip_img_bytes']}"
summary = _SUMMARIZER_AGENT(img_url)
message_content = summary.answer
return Command(
update={
"messages": [
ToolMessage(content=message_content, tool_call_id=tool_call_id),
],
},
)
+3 -3
View File
@@ -20,8 +20,7 @@ async def initialized_app():
@pytest.mark.xfail @pytest.mark.xfail
async def test_hello_world(initialized_app): async def test_call_api(initialized_app):
"""Hello world test for the API"""
async with AsyncClient( async with AsyncClient(
transport=ASGITransport(app=initialized_app), transport=ASGITransport(app=initialized_app),
base_url="http://test", base_url="http://test",
@@ -33,7 +32,7 @@ async def test_hello_world(initialized_app):
"agent_state_input": { "agent_state_input": {
"messages": [ "messages": [
{ {
"content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it", "content": "Find The Whitney Hotel Boston and buffer 0.1km around it, then fetch the NAIP imagery for the area from 2021 and summarize the contents of the image.",
"type": "human", "type": "human",
}, },
], ],
@@ -43,6 +42,7 @@ async def test_hello_world(initialized_app):
"thread_id": str(thread_id), "thread_id": str(thread_id),
}, },
) )
print(response)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "application/x-ndjson; charset=utf-8" assert response.headers["content-type"] == "application/x-ndjson; charset=utf-8"
+1 -1
View File
@@ -17,7 +17,7 @@ def geo_assistant_fixture():
place=place_geojson, place=place_geojson,
search_area=None, search_area=None,
messages=[], messages=[],
naip_png_path="path/to/naip.png", naip_img_bytes=None,
) )
+11 -9
View File
@@ -1,9 +1,11 @@
from types import NoneType from types import NoneType
import pytest import pytest
from geojson_pydantic import Feature
from langchain_core.tools.base import ToolCall from langchain_core.tools.base import ToolCall
from shapely.geometry import box, mapping from shapely.geometry import box, mapping
from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools.naip import fetch_naip_img from geo_assistant.tools.naip import fetch_naip_img
@@ -11,7 +13,7 @@ from geo_assistant.tools.naip import fetch_naip_img
async def test_fetch_naip(): async def test_fetch_naip():
""" """
Integration test: hit MPC STAC for NAIP around Union Market (DC), Integration test: hit MPC STAC for NAIP around Union Market (DC),
load imagery via odc-stac, and save an RGB PNG. load imagery via odc-stac, and save an RGB JPEG.
NOTE: This test requires: NOTE: This test requires:
- Internet access (to reach Planetary Computer STAC + blobs) - Internet access (to reach Planetary Computer STAC + blobs)
@@ -27,13 +29,13 @@ async def test_fetch_naip():
# ~0.0001 degrees buffer in each direction # ~0.0001 degrees buffer in each direction
aoi = box(lon - 0.0001, lat - 0.0001, lon + 0.0001, lat + 0.0001) aoi = box(lon - 0.0001, lat - 0.0001, lon + 0.0001, lat + 0.0001)
aoi_geojson = mapping(aoi) aoi_geojson = mapping(aoi)
aoi_feature = Feature(type="Feature", geometry=aoi_geojson, properties={})
tool_call = ToolCall( tool_call = ToolCall(
name="fetch_naip_img", name="fetch_naip_img",
args={ args={
"aoi_geojson": aoi_geojson,
"start_date": "2021-01-01", "start_date": "2021-01-01",
"end_date": "2021-12-31", "end_date": "2021-12-31",
"state": GeoAssistantState(search_area=aoi_feature, messages=[]),
}, },
type="tool_call", type="tool_call",
id="test_tool_call_id", id="test_tool_call_id",
@@ -42,9 +44,9 @@ async def test_fetch_naip():
# Call the actual tool - no STAC / odc-stac mocking # Call the actual tool - no STAC / odc-stac mocking
result = await fetch_naip_img.ainvoke(tool_call) result = await fetch_naip_img.ainvoke(tool_call)
assert "naip_img_bytes" in result.update assert "naip_img_bytes" in result.update
assert result.update["naip_img_bytes"] is not None, "Expected PNG bytes in result" assert result.update["naip_img_bytes"] is not None, "Expected JPEG bytes in result"
assert isinstance(result.update["naip_img_bytes"], bytes) assert isinstance(result.update["naip_img_bytes"], str)
assert len(result.update["naip_img_bytes"]) > 1, "Expected non-empty PNG bytes" assert len(result.update["naip_img_bytes"]) > 1, "Expected non-empty JPEG bytes"
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -68,13 +70,13 @@ async def test_fetch_naip_too_large():
# ~0.003 degrees buffer in each direction # ~0.003 degrees buffer in each direction
aoi = box(lon - 0.003, lat - 0.003, lon + 0.003, lat + 0.003) aoi = box(lon - 0.003, lat - 0.003, lon + 0.003, lat + 0.003)
aoi_geojson = mapping(aoi) aoi_geojson = mapping(aoi)
aoi_feature = Feature(type="Feature", geometry=aoi_geojson, properties={})
tool_call = ToolCall( tool_call = ToolCall(
name="fetch_naip_img", name="fetch_naip_img",
args={ args={
"aoi_geojson": aoi_geojson,
"start_date": "2021-01-01", "start_date": "2021-01-01",
"end_date": "2021-12-31", "end_date": "2021-12-31",
"state": GeoAssistantState(search_area=aoi_feature, messages=[]),
}, },
type="tool_call", type="tool_call",
id="test_tool_call_id", id="test_tool_call_id",
@@ -83,5 +85,5 @@ async def test_fetch_naip_too_large():
# Call the actual tool - no STAC / odc-stac mocking # Call the actual tool - no STAC / odc-stac mocking
result = await fetch_naip_img.ainvoke(tool_call) result = await fetch_naip_img.ainvoke(tool_call)
assert "naip_img_bytes" in result.update assert "naip_img_bytes" in result.update
assert result.update["naip_img_bytes"] is None, "Expected no PNG bytes in result" assert result.update["naip_img_bytes"] is None, "Expected no JPEG bytes in result"
assert isinstance(result.update["naip_img_bytes"], NoneType) assert isinstance(result.update["naip_img_bytes"], NoneType)
+18 -4
View File
@@ -1,29 +1,43 @@
"""Tests for the satellite image summarization tool.""" """Tests for the satellite image summarization tool."""
import base64
import uuid import uuid
import pytest import pytest
import requests
from langchain_core.tools.base import ToolCall from langchain_core.tools.base import ToolCall
from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools.summarize import summarize_sat_img from geo_assistant.tools.summarize import summarize_sat_img
# Sample test data # Sample test data
TEST_IMAGE_URL = "https://petapixel.com/assets/uploads/2022/08/French-Officials-Use-Satellite-Photos-and-AI-to-Spot-Unregistered-Pools-1536x806.jpg" TEST_IMAGE_URL = "https://petapixel.com/assets/uploads/2022/08/French-Officials-Use-Satellite-Photos-and-AI-to-Spot-Unregistered-Pools-1536x806.jpg"
@pytest.mark.xfail
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"img_url,summary", "img_url,summary",
[ [
(TEST_IMAGE_URL, "building"), (TEST_IMAGE_URL, "building"),
], ],
) )
@pytest.mark.xfail async def test_summarize_sat_img(img_url, summary):
def test_summarize_sat_img(img_url, summary): # Load the image from the supplied URL and encode it in base64
command = summarize_sat_img.invoke( headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
}
resp = requests.get(img_url, headers=headers)
resp.raise_for_status()
img_base64 = base64.b64encode(resp.content).decode("utf-8")
command = await summarize_sat_img.ainvoke(
ToolCall( ToolCall(
name="summarize_sat_img", name="summarize_sat_img",
type="tool_call", type="tool_call",
args={"img_url": img_url}, args={
"state": GeoAssistantState(naip_img_bytes=img_base64, messages=[]),
"tool_call_id": str(uuid.uuid4()),
},
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
), ),
) )