Render naip and summarize (#20)

* Intermediate

* Fix naip geom handling

* Fix imagery decoding for summary tool

* Re enable xfail

* Re enable xfail

* Remove png references
This commit is contained in:
Daniel Wiesmann
2025-12-05 15:39:16 +00:00
committed by GitHub
parent dddac818ea
commit 9b202c504e
9 changed files with 95 additions and 51 deletions
+2 -2
View File
@@ -8,7 +8,7 @@ from pydantic import Field
class GeoAssistantState(AgentState):
place: NotRequired[Feature | None] = None
search_area: NotRequired[Feature | None] = None
naip_img_bytes: NotRequired[bytes | None] = Field(
naip_img_bytes: NotRequired[str | None] = Field(
default=None,
description="Bytes of the saved NAIP RGB PNG image",
description="Base 64 encoded bytes str of the saved NAIP RGB JPEG image",
)
+1 -2
View File
@@ -1,4 +1,3 @@
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import aclosing, asynccontextmanager
@@ -96,7 +95,7 @@ async def stream_chat(
state = GeoAssistantState(**payload)
resp = ChatResponse(thread_id=str(thread_id), state=state)
line = json.dumps(resp.model_dump()) + "\n"
line = resp.model_dump_json() + "\n"
yield line.encode("utf-8")
+11 -2
View File
@@ -1,3 +1,4 @@
import base64
import json
import os
import uuid
@@ -68,8 +69,16 @@ def stream_chat(user_message: str):
for key, value in state.items():
if value and isinstance(value, dict) and value.get("type") == "Feature":
geojson_features[key] = value
# with st.chat_message("tool"):
# st.code(json.dumps(value, indent=2), language="json")
elif value and isinstance(value, str) and key == "naip_img_bytes":
# Handle base64-encoded jpeg data
try:
img_bytes = base64.b64decode(value)
with st.chat_message("tool"):
st.image(img_bytes)
except Exception:
# If decoding fails, fall through to JSON display
with st.chat_message("tool"):
st.code(json.dumps(value, indent=2), language="json")
elif value:
with st.chat_message("tool"):
st.code(json.dumps(value, indent=2), language="json")
+31 -17
View File
@@ -1,7 +1,8 @@
# tools/naip_mpc_tools.py
import base64
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Annotated, Any
from typing import Annotated
import dotenv
import matplotlib.pyplot as plt
@@ -10,41 +11,54 @@ import xarray as xr
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from odc.stac import stac_load
from pystac.extensions.raster import RasterBand
from pystac_client import Client
from geo_assistant.agent.state import GeoAssistantState
dotenv.load_dotenv()
DATA_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
# DATA_URL = "https://earth-search.aws.element84.com/v1"
@tool("fetch_naip_img")
async def fetch_naip_img(
aoi_geojson: dict[str, Any],
start_date: str,
end_date: str,
state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command:
"""
Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and
date range, load all matching items into an xarray data cube using odc-stac,
and save a simple RGB composite as a PNG.
and save a simple RGB composite as a JPEG.
Args:
aoi_geojson: GeoJSON Polygon/MultiPolygon in EPSG:4326.
start_date: Start date (YYYY-MM-DD).
end_date: End date (YYYY-MM-DD).
"""
if not state["search_area"]:
return Command(
update={
"messages": [
ToolMessage(
content="No search area avilable yetmee",
tool_call_id=tool_call_id,
),
],
"naip_img_bytes": None,
},
)
# --- 1. STAC search on the Microsoft Planetary Computer STAC API ---
catalog = Client.open(DATA_URL)
search = catalog.search(
collections=["naip"],
intersects=aoi_geojson,
intersects=state["search_area"].geometry,
datetime=f"{start_date}/{end_date}",
)
@@ -68,19 +82,19 @@ async def fetch_naip_img(
tool_call_id=tool_call_id,
),
],
"naip_png_path": None,
"naip_img_bytes": None,
},
)
# --- 2. Load as xarray cube with odc.stac ---
# NAIP in MPC: 4-band multi-band asset (R,G,B,NIR) in one asset named "image".
# odc.stac exposes these as measurements 'red','green','blue','nir' for this collection
# Limit to first item for now
with ThreadPoolExecutor(max_workers=5) as executor:
ds: xr.Dataset = stac_load(
items,
items[:1],
bands=["red", "green", "blue"], # use only RGB
geopolygon=aoi_geojson,
geopolygon=state["search_area"].geometry,
resolution=1.0, # NAIP native ~1 m
executor=executor,
crs=items[0].properties["proj:code"],
@@ -117,7 +131,7 @@ async def fetch_naip_img(
)
# --- 3. Build an RGB composite from the cube ---
# For the PNG, we'll just use the first time slice (you can swap in “latest”
# For the JPEG, we'll just use the first time slice (you can swap in “latest”
# or a temporal reduction if you prefer).
red = ds["red"].isel(time=0)
green = ds["green"].isel(time=0)
@@ -127,7 +141,7 @@ async def fetch_naip_img(
rgb = xr.concat([red, green, blue], dim="band") # (band, y, x)
rgb = rgb.transpose("y", "x", "band") # (y, x, band)
# Convert to uint8 for PNG with a simple contrast stretch.
# Convert to uint8 for JPEG with a simple contrast stretch.
arr = rgb.values.astype("float32")
# Robust min/max to avoid a few hot pixels blowing out the stretch
vmin = np.nanpercentile(arr, 2)
@@ -138,21 +152,21 @@ async def fetch_naip_img(
arr = np.clip((arr - vmin) / (vmax - vmin + 1e-6), 0, 1)
arr_uint8 = (arr * 255).astype("uint8")
# --- 4. Save PNG ---
# --- 4. Save image ---
buf = BytesIO()
plt.imsave(buf, arr_uint8, format="png")
plt.imsave(buf, arr_uint8, format="jpeg")
buf.seek(0)
img_bytes = buf.getvalue()
img_base64 = base64.b64encode(buf.read()).decode("utf-8")
return Command(
update={
"messages": [
ToolMessage(
content="NAIP RGB image fetched and encoded as PNG bytes.",
content="NAIP RGB image fetched and encoded as JPEG bytes.",
tool_call_id=tool_call_id,
),
],
"naip_img_bytes": img_bytes,
"naip_img_bytes": img_base64,
},
)
+17 -11
View File
@@ -8,8 +8,11 @@ import dspy
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from geo_assistant.agent.state import GeoAssistantState
dotenv.load_dotenv()
@@ -67,7 +70,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
@tool
async def summarize_sat_img(
img_url: str,
state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command:
"""Summarize the contents of a satellite image using an LLM.
@@ -82,21 +85,24 @@ async def summarize_sat_img(
Raises:
ValueError: If the image URL is invalid or the image cannot be processed
"""
if not img_url or not isinstance(img_url, str):
raise ValueError("img_url must be a non-empty string")
if not state["naip_img_bytes"]:
return Command(
update={
"messages": [
ToolMessage(
content="No NAIP image bytes available yet",
tool_call_id=tool_call_id,
),
],
},
)
img_url = f"data:image/jpeg;base64,{state['naip_img_bytes']}"
summary = _SUMMARIZER_AGENT(img_url)
message_content = summary.answer
artifact = {"img_url": img_url}
return Command(
update={
"messages": [
ToolMessage(
content=message_content,
artifact=artifact,
tool_call_id=tool_call_id,
),
ToolMessage(content=message_content, tool_call_id=tool_call_id),
],
},
)