diff --git a/src/geo_assistant/agent/state.py b/src/geo_assistant/agent/state.py index 00ef5b4..8d0aca9 100644 --- a/src/geo_assistant/agent/state.py +++ b/src/geo_assistant/agent/state.py @@ -8,7 +8,7 @@ from pydantic import Field class GeoAssistantState(AgentState): place: NotRequired[Feature | None] = None search_area: NotRequired[Feature | None] = None - naip_img_bytes: NotRequired[bytes | None] = Field( + naip_img_bytes: NotRequired[str | None] = Field( default=None, - description="Bytes of the saved NAIP RGB PNG image", + description="Base 64 encoded bytes str of the saved NAIP RGB JPEG image", ) diff --git a/src/geo_assistant/api/app.py b/src/geo_assistant/api/app.py index c8680ea..f08a191 100644 --- a/src/geo_assistant/api/app.py +++ b/src/geo_assistant/api/app.py @@ -1,4 +1,3 @@ -import json import logging from collections.abc import AsyncGenerator from contextlib import aclosing, asynccontextmanager @@ -96,7 +95,7 @@ async def stream_chat( state = GeoAssistantState(**payload) resp = ChatResponse(thread_id=str(thread_id), state=state) - line = json.dumps(resp.model_dump()) + "\n" + line = resp.model_dump_json() + "\n" yield line.encode("utf-8") diff --git a/src/geo_assistant/frontend/app.py b/src/geo_assistant/frontend/app.py index e0fb44d..8bc5c89 100644 --- a/src/geo_assistant/frontend/app.py +++ b/src/geo_assistant/frontend/app.py @@ -1,3 +1,4 @@ +import base64 import json import os import uuid @@ -68,8 +69,16 @@ def stream_chat(user_message: str): for key, value in state.items(): if value and isinstance(value, dict) and value.get("type") == "Feature": geojson_features[key] = value - # with st.chat_message("tool"): - # st.code(json.dumps(value, indent=2), language="json") + elif value and isinstance(value, str) and key == "naip_img_bytes": + # Handle base64-encoded jpeg data + try: + img_bytes = base64.b64decode(value) + with st.chat_message("tool"): + st.image(img_bytes) + except Exception: + # If decoding fails, fall through to JSON display + with st.chat_message("tool"): + st.code(json.dumps(value, indent=2), language="json") elif value: with st.chat_message("tool"): st.code(json.dumps(value, indent=2), language="json") diff --git a/src/geo_assistant/tools/naip.py b/src/geo_assistant/tools/naip.py index bd601ce..63969d4 100644 --- a/src/geo_assistant/tools/naip.py +++ b/src/geo_assistant/tools/naip.py @@ -1,7 +1,8 @@ # tools/naip_mpc_tools.py +import base64 from concurrent.futures import ThreadPoolExecutor from io import BytesIO -from typing import Annotated, Any +from typing import Annotated import dotenv import matplotlib.pyplot as plt @@ -10,41 +11,54 @@ import xarray as xr from langchain_core.messages import ToolMessage from langchain_core.tools import tool from langchain_core.tools.base import InjectedToolCallId +from langgraph.prebuilt import InjectedState from langgraph.types import Command from odc.stac import stac_load from pystac.extensions.raster import RasterBand from pystac_client import Client +from geo_assistant.agent.state import GeoAssistantState + dotenv.load_dotenv() DATA_URL = "https://planetarycomputer.microsoft.com/api/stac/v1" -# DATA_URL = "https://earth-search.aws.element84.com/v1" @tool("fetch_naip_img") async def fetch_naip_img( - aoi_geojson: dict[str, Any], start_date: str, end_date: str, + state: Annotated[GeoAssistantState, InjectedState], tool_call_id: Annotated[str | None, InjectedToolCallId] = None, ) -> Command: """ Query Microsoft Planetary Computer for NAIP imagery intersecting an AOI and date range, load all matching items into an xarray data cube using odc-stac, - and save a simple RGB composite as a PNG. + and save a simple RGB composite as a JPEG. Args: - aoi_geojson: GeoJSON Polygon/MultiPolygon in EPSG:4326. start_date: Start date (YYYY-MM-DD). end_date: End date (YYYY-MM-DD). """ + if not state["search_area"]: + return Command( + update={ + "messages": [ + ToolMessage( + content="No search area avilable yetmee", + tool_call_id=tool_call_id, + ), + ], + "naip_img_bytes": None, + }, + ) # --- 1. STAC search on Element84's EarthSearch API --- catalog = Client.open(DATA_URL) search = catalog.search( collections=["naip"], - intersects=aoi_geojson, + intersects=state["search_area"].geometry, datetime=f"{start_date}/{end_date}", ) @@ -68,19 +82,19 @@ async def fetch_naip_img( tool_call_id=tool_call_id, ), ], - "naip_png_path": None, + "naip_img_bytes": None, }, ) # --- 2. Load as xarray cube with odc.stac --- # NAIP in MPC: 4-band multi-band asset (R,G,B,NIR) in one asset named "image". # odc.stac exposes these as measurements 'red','green','blue','nir' for this collection - + # Limit to first item for now with ThreadPoolExecutor(max_workers=5) as executor: ds: xr.Dataset = stac_load( - items, + items[:1], bands=["red", "green", "blue"], # use only RGB - geopolygon=aoi_geojson, + geopolygon=state["search_area"].geometry, resolution=1.0, # NAIP native ~1 m executor=executor, crs=items[0].properties["proj:code"], @@ -117,7 +131,7 @@ async def fetch_naip_img( ) # --- 3. Build an RGB composite from the cube --- - # For the PNG, we'll just use the first time slice (you can swap in “latest” + # For the JPEG, we'll just use the first time slice (you can swap in “latest” # or a temporal reduction if you prefer). red = ds["red"].isel(time=0) green = ds["green"].isel(time=0) @@ -127,7 +141,7 @@ async def fetch_naip_img( rgb = xr.concat([red, green, blue], dim="band") # (band, y, x) rgb = rgb.transpose("y", "x", "band") # (y, x, band) - # Convert to uint8 for PNG with a simple contrast stretch. + # Convert to uint8 for JPEG with a simple contrast stretch. arr = rgb.values.astype("float32") # Robust min/max to avoid a few hot pixels blowing out the stretch vmin = np.nanpercentile(arr, 2) @@ -138,21 +152,21 @@ async def fetch_naip_img( arr = np.clip((arr - vmin) / (vmax - vmin + 1e-6), 0, 1) arr_uint8 = (arr * 255).astype("uint8") - # --- 4. Save PNG --- + # --- 4. Save image --- buf = BytesIO() - plt.imsave(buf, arr_uint8, format="png") + plt.imsave(buf, arr_uint8, format="jpeg") buf.seek(0) - img_bytes = buf.getvalue() + img_base64 = base64.b64encode(buf.read()).decode("utf-8") return Command( update={ "messages": [ ToolMessage( - content="NAIP RGB image fetched and encoded as PNG bytes.", + content="NAIP RGB image fetched and encoded as JPEG bytes.", tool_call_id=tool_call_id, ), ], - "naip_img_bytes": img_bytes, + "naip_img_bytes": img_base64, }, ) diff --git a/src/geo_assistant/tools/summarize.py b/src/geo_assistant/tools/summarize.py index 70b38fe..01bd2fb 100644 --- a/src/geo_assistant/tools/summarize.py +++ b/src/geo_assistant/tools/summarize.py @@ -8,8 +8,11 @@ import dspy from langchain_core.messages import ToolMessage from langchain_core.tools import tool from langchain_core.tools.base import InjectedToolCallId +from langgraph.prebuilt import InjectedState from langgraph.types import Command +from geo_assistant.agent.state import GeoAssistantState + dotenv.load_dotenv() @@ -67,7 +70,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent() @tool async def summarize_sat_img( - img_url: str, + state: Annotated[GeoAssistantState, InjectedState], tool_call_id: Annotated[str | None, InjectedToolCallId] = None, ) -> Command: """Summarize the contents of a satellite image using an LLM. @@ -82,21 +85,24 @@ async def summarize_sat_img( Raises: ValueError: If the image URL is invalid or the image cannot be processed """ - if not img_url or not isinstance(img_url, str): - raise ValueError("img_url must be a non-empty string") - + if not state["naip_img_bytes"]: + return Command( + update={ + "messages": [ + ToolMessage( + content="No NAIP image bytes available yet", + tool_call_id=tool_call_id, + ), + ], + }, + ) + img_url = f"data:image/jpeg;base64,{state['naip_img_bytes']}" summary = _SUMMARIZER_AGENT(img_url) message_content = summary.answer - artifact = {"img_url": img_url} - return Command( update={ "messages": [ - ToolMessage( - content=message_content, - artifact=artifact, - tool_call_id=tool_call_id, - ), + ToolMessage(content=message_content, tool_call_id=tool_call_id), ], }, ) diff --git a/tests/test_api.py b/tests/test_api.py index a459e14..796ed96 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -20,8 +20,7 @@ async def initialized_app(): @pytest.mark.xfail -async def test_hello_world(initialized_app): - """Hello world test for the API""" +async def test_call_api(initialized_app): async with AsyncClient( transport=ASGITransport(app=initialized_app), base_url="http://test", @@ -33,7 +32,7 @@ async def test_hello_world(initialized_app): "agent_state_input": { "messages": [ { - "content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it", + "content": "Find The Whitney Hotel Boston and buffer 0.1km around it, then fetch the NAIP imagery for the area from 2021 and summarize the contents of the image.", "type": "human", }, ], @@ -43,6 +42,7 @@ async def test_hello_world(initialized_app): "thread_id": str(thread_id), }, ) + print(response) assert response.status_code == 200 assert response.headers["content-type"] == "application/x-ndjson; charset=utf-8" diff --git a/tests/tools/test_buffer.py b/tests/tools/test_buffer.py index 03c7404..44ab50d 100644 --- a/tests/tools/test_buffer.py +++ b/tests/tools/test_buffer.py @@ -17,7 +17,7 @@ def geo_assistant_fixture(): place=place_geojson, search_area=None, messages=[], - naip_png_path="path/to/naip.png", + naip_img_bytes=None, ) diff --git a/tests/tools/test_naip.py b/tests/tools/test_naip.py index d6c12ef..269769a 100644 --- a/tests/tools/test_naip.py +++ b/tests/tools/test_naip.py @@ -1,9 +1,11 @@ from types import NoneType import pytest +from geojson_pydantic import Feature from langchain_core.tools.base import ToolCall from shapely.geometry import box, mapping +from geo_assistant.agent.state import GeoAssistantState from geo_assistant.tools.naip import fetch_naip_img @@ -11,7 +13,7 @@ from geo_assistant.tools.naip import fetch_naip_img async def test_fetch_naip(): """ Integration test: hit MPC STAC for NAIP around Union Market (DC), - load imagery via odc-stac, and save an RGB PNG. + load imagery via odc-stac, and save an RGB JPEG. NOTE: This test requires: - Internet access (to reach Planetary Computer STAC + blobs) @@ -27,13 +29,13 @@ async def test_fetch_naip(): # ~0.0001 degrees buffer in each direction aoi = box(lon - 0.0001, lat - 0.0001, lon + 0.0001, lat + 0.0001) aoi_geojson = mapping(aoi) - + aoi_feature = Feature(type="Feature", geometry=aoi_geojson, properties={}) tool_call = ToolCall( name="fetch_naip_img", args={ - "aoi_geojson": aoi_geojson, "start_date": "2021-01-01", "end_date": "2021-12-31", + "state": GeoAssistantState(search_area=aoi_feature, messages=[]), }, type="tool_call", id="test_tool_call_id", @@ -42,9 +44,9 @@ async def test_fetch_naip(): # Call the actual tool - no STAC / odc-stac mocking result = await fetch_naip_img.ainvoke(tool_call) assert "naip_img_bytes" in result.update - assert result.update["naip_img_bytes"] is not None, "Expected PNG bytes in result" - assert isinstance(result.update["naip_img_bytes"], bytes) - assert len(result.update["naip_img_bytes"]) > 1, "Expected non-empty PNG bytes" + assert result.update["naip_img_bytes"] is not None, "Expected JPEG bytes in result" + assert isinstance(result.update["naip_img_bytes"], str) + assert len(result.update["naip_img_bytes"]) > 1, "Expected non-empty JPEG bytes" @pytest.mark.asyncio @@ -68,13 +70,13 @@ async def test_fetch_naip_too_large(): # ~0.003 degrees buffer in each direction aoi = box(lon - 0.003, lat - 0.003, lon + 0.003, lat + 0.003) aoi_geojson = mapping(aoi) - + aoi_feature = Feature(type="Feature", geometry=aoi_geojson, properties={}) tool_call = ToolCall( name="fetch_naip_img", args={ - "aoi_geojson": aoi_geojson, "start_date": "2021-01-01", "end_date": "2021-12-31", + "state": GeoAssistantState(search_area=aoi_feature, messages=[]), }, type="tool_call", id="test_tool_call_id", @@ -83,5 +85,5 @@ async def test_fetch_naip_too_large(): # Call the actual tool - no STAC / odc-stac mocking result = await fetch_naip_img.ainvoke(tool_call) assert "naip_img_bytes" in result.update - assert result.update["naip_img_bytes"] is None, "Expected no PNG bytes in result" + assert result.update["naip_img_bytes"] is None, "Expected no JPEG bytes in result" assert isinstance(result.update["naip_img_bytes"], NoneType) diff --git a/tests/tools/test_summarize.py b/tests/tools/test_summarize.py index 9d629c9..532ab4e 100644 --- a/tests/tools/test_summarize.py +++ b/tests/tools/test_summarize.py @@ -1,29 +1,43 @@ """Tests for the satellite image summarization tool.""" +import base64 import uuid import pytest +import requests from langchain_core.tools.base import ToolCall +from geo_assistant.agent.state import GeoAssistantState from geo_assistant.tools.summarize import summarize_sat_img # Sample test data TEST_IMAGE_URL = "https://petapixel.com/assets/uploads/2022/08/French-Officials-Use-Satellite-Photos-and-AI-to-Spot-Unregistered-Pools-1536x806.jpg" +@pytest.mark.xfail +@pytest.mark.asyncio @pytest.mark.parametrize( "img_url,summary", [ (TEST_IMAGE_URL, "building"), ], ) -@pytest.mark.xfail -def test_summarize_sat_img(img_url, summary): - command = summarize_sat_img.invoke( +async def test_summarize_sat_img(img_url, summary): + # Load the image from the supplied URL and encode it in base64 + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + } + resp = requests.get(img_url, headers=headers) + resp.raise_for_status() + img_base64 = base64.b64encode(resp.content).decode("utf-8") + command = await summarize_sat_img.ainvoke( ToolCall( name="summarize_sat_img", type="tool_call", - args={"img_url": img_url}, + args={ + "state": GeoAssistantState(naip_img_bytes=img_base64, messages=[]), + "tool_call_id": str(uuid.uuid4()), + }, id=str(uuid.uuid4()), ), )