mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-15 15:31:02 +02:00
Tool/naip fetcher (#16)
* Add initial NAIP fetcher * Swap to Element84's EarthSearch API for NAIP STAC search and download * clip to bounds of aoi * Swap to Element84's EarthSearch API for NAIP STAC search and download * rename bands and remove dask chunking * Add DS_Store to .gitignore * restrict date range for naip test * Adjust timerange for tests * Add xarray to pyproj * Reduce aoi size * revert test to use tmp path * Update return types for tool to ensure state gets updated * Update unit test for compatibility with Command output from tool * Save image bytes directly to graph state for summarizer * add safeguard against large image sizes * remove print statement * Fix stac.load to work with MCP API by manually inserting band data into the extension * Remove deleted file * Add comment explaining workaround --------- Co-authored-by: lillythomas <lillyelizathomas@gmail.com> Co-authored-by: Daniel Wiesmann <yellowcap@users.noreply.github.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from pydantic import Field
|
|||||||
class GeoAssistantState(AgentState):
|
class GeoAssistantState(AgentState):
|
||||||
place: NotRequired[Feature | None] = None
|
place: NotRequired[Feature | None] = None
|
||||||
search_area: NotRequired[Feature | None] = None
|
search_area: NotRequired[Feature | None] = None
|
||||||
naip_png_path: NotRequired[str | None] = Field(
|
naip_img_bytes: NotRequired[bytes | None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Path to the saved NAIP RGB PNG image",
|
description="Bytes of the saved NAIP RGB PNG image",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# tools/naip_mpc_tools.py
|
# tools/naip_mpc_tools.py
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from io import BytesIO
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any
|
||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
@@ -12,12 +12,13 @@ from langchain_core.tools import tool
|
|||||||
from langchain_core.tools.base import InjectedToolCallId
|
from langchain_core.tools.base import InjectedToolCallId
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
from odc.stac import stac_load
|
from odc.stac import stac_load
|
||||||
|
from pystac.extensions.raster import RasterBand
|
||||||
from pystac_client import Client
|
from pystac_client import Client
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
# PC_STAC_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
|
DATA_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
|
||||||
E84_STAC_URL = "https://earth-search.aws.element84.com/v1"
|
# DATA_URL = "https://earth-search.aws.element84.com/v1"
|
||||||
|
|
||||||
|
|
||||||
@tool("fetch_naip_img")
|
@tool("fetch_naip_img")
|
||||||
@@ -39,7 +40,7 @@ async def fetch_naip_img(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
# --- 1. STAC search on Element84's EarthSearch API ---
|
# --- 1. STAC search on Element84's EarthSearch API ---
|
||||||
catalog = Client.open(E84_STAC_URL)
|
catalog = Client.open(DATA_URL)
|
||||||
|
|
||||||
search = catalog.search(
|
search = catalog.search(
|
||||||
collections=["naip"],
|
collections=["naip"],
|
||||||
@@ -48,6 +49,16 @@ async def fetch_naip_img(
|
|||||||
)
|
)
|
||||||
|
|
||||||
items = list(search.items())
|
items = list(search.items())
|
||||||
|
|
||||||
|
# This is a hack to add raster extension info to the items, since
|
||||||
|
# the Planetary Computer STAC API adds the band information using the
|
||||||
|
# eo:bands extension, but odc.stac expects the raster:bands extension.
|
||||||
|
for item in items:
|
||||||
|
item.assets["image"].ext.add("raster")
|
||||||
|
item.assets["image"].ext.raster.bands = [
|
||||||
|
RasterBand.create() for _ in ("red", "green", "blue", "nir")
|
||||||
|
]
|
||||||
|
|
||||||
if len(items) == 0:
|
if len(items) == 0:
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={
|
||||||
@@ -61,18 +72,20 @@ async def fetch_naip_img(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 2. Load as xarray cube with odc.stac ---
|
# --- 2. Load as xarray cube with odc.stac ---
|
||||||
# NAIP in MPC: 4-band multi-band asset (R,G,B,NIR) in one asset named "image".
|
# NAIP in MPC: 4-band multi-band asset (R,G,B,NIR) in one asset named "image".
|
||||||
# odc.stac exposes these as measurements 'red','green','blue','nir' for this collection
|
# odc.stac exposes these as measurements 'red','green','blue','nir' for this collection
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||||
ds: xr.Dataset = stac_load(
|
ds: xr.Dataset = stac_load(
|
||||||
items,
|
items,
|
||||||
bands=["Red", "Green", "Blue"], # use only RGB
|
bands=["red", "green", "blue"], # use only RGB
|
||||||
geopolygon=aoi_geojson,
|
geopolygon=aoi_geojson,
|
||||||
resolution=1.0, # NAIP native ~1 m
|
resolution=1.0, # NAIP native ~1 m
|
||||||
executor=executor,
|
executor=executor,
|
||||||
|
crs=items[0].properties["proj:code"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if ds.dims.get("time", 0) == 0:
|
if ds.dims.get("time", 0) == 0:
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={
|
||||||
@@ -82,16 +95,33 @@ async def fetch_naip_img(
|
|||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
"naip_png_path": None,
|
"naip_img_bytes": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enforce max output size based on dataset sizes (y, x)
|
||||||
|
sizes = dict(ds.sizes)
|
||||||
|
h = int(sizes.get("y", 0))
|
||||||
|
w = int(sizes.get("x", 0))
|
||||||
|
if h > 512 or w > 512:
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=f"NAIP RGB image {w}x{h} exceeds 512x512 limit. Skipping image output.",
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"naip_img_bytes": None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 3. Build an RGB composite from the cube ---
|
# --- 3. Build an RGB composite from the cube ---
|
||||||
# For the PNG, we'll just use the first time slice (you can swap in “latest”
|
# For the PNG, we'll just use the first time slice (you can swap in “latest”
|
||||||
# or a temporal reduction if you prefer).
|
# or a temporal reduction if you prefer).
|
||||||
red = ds["Red"].isel(time=0)
|
red = ds["red"].isel(time=0)
|
||||||
green = ds["Green"].isel(time=0)
|
green = ds["green"].isel(time=0)
|
||||||
blue = ds["Blue"].isel(time=0)
|
blue = ds["blue"].isel(time=0)
|
||||||
|
|
||||||
# Stack into (y, x, 3) array
|
# Stack into (y, x, 3) array
|
||||||
rgb = xr.concat([red, green, blue], dim="band") # (band, y, x)
|
rgb = xr.concat([red, green, blue], dim="band") # (band, y, x)
|
||||||
@@ -110,18 +140,19 @@ async def fetch_naip_img(
|
|||||||
|
|
||||||
# --- 4. Save PNG ---
|
# --- 4. Save PNG ---
|
||||||
|
|
||||||
out_path = Path("naip_rgb.png")
|
buf = BytesIO()
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
plt.imsave(buf, arr_uint8, format="png")
|
||||||
plt.imsave(out_path.as_posix(), arr_uint8)
|
buf.seek(0)
|
||||||
|
img_bytes = buf.getvalue()
|
||||||
|
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={
|
||||||
"messages": [
|
"messages": [
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=f"NAIP RGB image saved to {out_path.as_posix()}",
|
content="NAIP RGB image fetched and encoded as PNG bytes.",
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
"naip_png_path": out_path.as_posix(),
|
"naip_img_bytes": img_bytes,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
+46
-18
@@ -1,4 +1,4 @@
|
|||||||
from pathlib import Path
|
from types import NoneType
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.tools.base import ToolCall
|
from langchain_core.tools.base import ToolCall
|
||||||
@@ -8,8 +8,7 @@ from geo_assistant.tools.naip import fetch_naip_img
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.xfail
|
async def test_fetch_naip():
|
||||||
async def test_fetch_naip(tmp_path):
|
|
||||||
"""
|
"""
|
||||||
Integration test: hit MPC STAC for NAIP around Union Market (DC),
|
Integration test: hit MPC STAC for NAIP around Union Market (DC),
|
||||||
load imagery via odc-stac, and save an RGB PNG.
|
load imagery via odc-stac, and save an RGB PNG.
|
||||||
@@ -22,38 +21,67 @@ async def test_fetch_naip(tmp_path):
|
|||||||
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
||||||
# N 38°54'28" W 76°59'54"
|
# N 38°54'28" W 76°59'54"
|
||||||
# We'll use a small neighborhood AOI around that point.
|
# We'll use a small neighborhood AOI around that point.
|
||||||
# :contentReference[oaicite:0]{index=0}
|
|
||||||
lat = 38.90789
|
lat = 38.90789
|
||||||
lon = -76.99831
|
lon = -76.99831
|
||||||
|
|
||||||
# ~0.01 degrees (~1.1 km) buffer in each direction
|
# ~0.0001 degrees buffer in each direction
|
||||||
aoi = box(lon - 0.0001, lat - 0.0001, lon + 0.0001, lat + 0.0001)
|
aoi = box(lon - 0.0001, lat - 0.0001, lon + 0.0001, lat + 0.0001)
|
||||||
aoi_geojson = mapping(aoi)
|
aoi_geojson = mapping(aoi)
|
||||||
|
|
||||||
out_png = tmp_path / "naip_test_img.png"
|
|
||||||
|
|
||||||
tool_call = ToolCall(
|
tool_call = ToolCall(
|
||||||
name="fetch_naip_img",
|
name="fetch_naip_img",
|
||||||
args={
|
args={
|
||||||
"aoi_geojson": aoi_geojson,
|
"aoi_geojson": aoi_geojson,
|
||||||
"start_date": "2021-01-01",
|
"start_date": "2021-01-01",
|
||||||
"end_date": "2021-12-31",
|
"end_date": "2021-12-31",
|
||||||
"out_png_path": str(out_png),
|
|
||||||
"resolution": 1.0,
|
|
||||||
},
|
},
|
||||||
type="tool_call",
|
type="tool_call",
|
||||||
id="test_tool_call_id",
|
id="test_tool_call_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the actual tool - no STAC / odc-stac mocking
|
# Call the actual tool - no STAC / odc-stac mocking
|
||||||
result = await fetch_naip_img.ainvoke(tool_call["args"])
|
result = await fetch_naip_img.ainvoke(tool_call)
|
||||||
|
assert "naip_img_bytes" in result.update
|
||||||
|
assert result.update["naip_img_bytes"] is not None, "Expected PNG bytes in result"
|
||||||
|
assert isinstance(result.update["naip_img_bytes"], bytes)
|
||||||
|
assert len(result.update["naip_img_bytes"]) > 1, "Expected non-empty PNG bytes"
|
||||||
|
|
||||||
# Basic sanity checks on result
|
|
||||||
assert result["stac_item_count"] > 0, "Expected at least one NAIP item"
|
|
||||||
assert "time" in result["dataset_dims"]
|
|
||||||
assert result["dataset_dims"]["time"] >= 1
|
|
||||||
|
|
||||||
# PNG should have been written to disk
|
@pytest.mark.asyncio
|
||||||
png_path = Path(result["png_path"])
|
async def test_fetch_naip_too_large():
|
||||||
assert png_path == out_png
|
"""
|
||||||
assert png_path.is_file(), f"PNG was not created at {png_path}"
|
Integration test: request a larger AOI that should produce an image
|
||||||
|
exceeding the 512x512 pixel limit. The tool should return no image
|
||||||
|
bytes and include a message indicating it skipped output due to size.
|
||||||
|
|
||||||
|
NOTE: This test requires:
|
||||||
|
- Internet access (to reach Planetary Computer STAC + blobs)
|
||||||
|
- Planetary Computer / NAIP service to be up
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Union Market coordinates from GeoNames: 38.90789, -76.99831
|
||||||
|
# N 38°54'28" W 76°59'54"
|
||||||
|
# We'll use a small neighborhood AOI around that point.
|
||||||
|
lat = 38.90789
|
||||||
|
lon = -76.99831
|
||||||
|
|
||||||
|
# ~0.003 degrees buffer in each direction
|
||||||
|
aoi = box(lon - 0.003, lat - 0.003, lon + 0.003, lat + 0.003)
|
||||||
|
aoi_geojson = mapping(aoi)
|
||||||
|
|
||||||
|
tool_call = ToolCall(
|
||||||
|
name="fetch_naip_img",
|
||||||
|
args={
|
||||||
|
"aoi_geojson": aoi_geojson,
|
||||||
|
"start_date": "2021-01-01",
|
||||||
|
"end_date": "2021-12-31",
|
||||||
|
},
|
||||||
|
type="tool_call",
|
||||||
|
id="test_tool_call_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the actual tool - no STAC / odc-stac mocking
|
||||||
|
result = await fetch_naip_img.ainvoke(tool_call)
|
||||||
|
assert "naip_img_bytes" in result.update
|
||||||
|
assert result.update["naip_img_bytes"] is None, "Expected no PNG bytes in result"
|
||||||
|
assert isinstance(result.update["naip_img_bytes"], NoneType)
|
||||||
|
|||||||
Reference in New Issue
Block a user