feat: overture places within buffer tool (#18)

* wip query places within buffer

* places within search area query working

* make async

* remove unused fields

* formatting

* remove literal type for more flexibile, but still map user input to 3 high level Overture categories- cafe, restaurant, and bar

* actually return pydantic FeatureCollection

* better formatting for tool message

* fix formatting

* fix init

* actually fix imports

* fix linting

* clearer agent instructions

* render the FeatureCollection automatically in steamlit
This commit is contained in:
Martha Morrissey
2025-12-05 10:28:37 -07:00
committed by GitHub
parent da5ebb6601
commit c677057b2e
6 changed files with 245 additions and 21 deletions
+4 -1
View File
@@ -8,6 +8,7 @@ from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools import (
fetch_naip_img,
get_place,
get_places_within_buffer,
get_search_area,
summarize_sat_img,
)
@@ -19,10 +20,11 @@ You have the following tools available to you.
- get_place: Get a place from the Overture Maps database
- get_search_area: Get a search area buffer in km around the place defined in the agent state
- get_places_within_buffer: Get places from the Overture Maps database within the search area defined in the agent state
- summarize_sat_img: Summarize the contents of a satellite image using an LLM
- fetch_naip_img: A NAIP imagery fetch tool. Use this to fetch NAIP aerial imagery for a given area of interest returned by the overture location lookup tool and date range (do your best to extract the date range from the user's query if provided, otherwise ask the user to specify a date range)
For places if you have links to social media, include them in the response.
Do not use background knowledge, only use the tools above to answer questions.
The current date and time is {now}.
"""
@@ -35,6 +37,7 @@ async def create_graph():
tools=[
get_place,
get_search_area,
get_places_within_buffer,
fetch_naip_img,
summarize_sat_img,
],
+2 -1
View File
@@ -1,6 +1,6 @@
from typing import NotRequired
from geojson_pydantic import Feature
from geojson_pydantic import Feature, FeatureCollection
from langchain.agents import AgentState
from pydantic import Field
@@ -8,6 +8,7 @@ from pydantic import Field
class GeoAssistantState(AgentState):
place: NotRequired[Feature | None] = None
search_area: NotRequired[Feature | None] = None
places_within_buffer: NotRequired[FeatureCollection | None] = None
naip_img_bytes: NotRequired[str | None] = Field(
default=None,
description="Base 64 encoded bytes str of the saved NAIP RGB JPEG image",
+44 -16
View File
@@ -67,7 +67,11 @@ def stream_chat(user_message: str):
# Check for GeoJSON features and render map if present
geojson_features = {}
for key, value in state.items():
if value and isinstance(value, dict) and value.get("type") == "Feature":
if (
value
and isinstance(value, dict)
and value.get("type") in ["Feature", "FeatureCollection"]
):
geojson_features[key] = value
elif value and isinstance(value, str) and key == "naip_img_bytes":
# Handle base64-encoded jpeg data
@@ -107,13 +111,20 @@ def stream_chat(user_message: str):
# Calculate center from all features
all_lons, all_lats = [], []
for feature in geojson_features.values():
geom = feature.get("geometry", {})
coords = get_coords_from_geometry(geom)
for coord in coords:
if len(coord) >= 2:
all_lons.append(coord[0])
all_lats.append(coord[1])
for feature_data in geojson_features.values():
# Handle both Feature and FeatureCollection
if feature_data.get("type") == "FeatureCollection":
features = feature_data.get("features", [])
else:
features = [feature_data]
for feature in features:
geom = feature.get("geometry", {})
coords = get_coords_from_geometry(geom)
for coord in coords:
if len(coord) >= 2:
all_lons.append(coord[0])
all_lats.append(coord[1])
if all_lons and all_lats:
center_lat = sum(all_lats) / len(all_lats)
@@ -124,7 +135,11 @@ def stream_chat(user_message: str):
m = folium.Map(location=[center_lat, center_lon], zoom_start=10)
# Add features to map with different colors
colors = {"place": "blue", "search_area": "red"}
colors = {
"place": "blue",
"search_area": "red",
"places_within_buffer": "green",
}
def make_style_function(color):
"""Create a style function with the given color."""
@@ -135,13 +150,26 @@ def stream_chat(user_message: str):
"fillOpacity": 0.3,
}
for key, feature in geojson_features.items():
color = colors.get(key, "green")
folium.GeoJson(
feature,
style_function=make_style_function(color),
tooltip=key,
).add_to(m)
for key, feature_data in geojson_features.items():
color = colors.get(key, "purple")
# Handle both Feature and FeatureCollection
if feature_data.get("type") == "FeatureCollection":
# For FeatureCollections, add each feature with a popup showing its name
for feature in feature_data.get("features", []):
name = feature.get("properties", {}).get("name", key)
folium.GeoJson(
feature,
style_function=make_style_function(color),
tooltip=name,
).add_to(m)
else:
# For single Features
folium.GeoJson(
feature_data,
style_function=make_style_function(color),
tooltip=key,
).add_to(m)
# Fit map to bounds if we have coordinates
if all_lons and all_lats:
+8 -2
View File
@@ -1,6 +1,12 @@
from geo_assistant.tools.buffer import get_search_area
from geo_assistant.tools.naip import fetch_naip_img
from geo_assistant.tools.overture import get_place
from geo_assistant.tools.overture import get_place, get_places_within_buffer
from geo_assistant.tools.summarize import summarize_sat_img
__all__ = ["fetch_naip_img", "get_place", "get_search_area", "summarize_sat_img"]
__all__ = [
"fetch_naip_img",
"get_place",
"get_places_within_buffer",
"get_search_area",
"summarize_sat_img",
]
+132 -1
View File
@@ -3,12 +3,18 @@ import os
from typing import Annotated
import duckdb
import geopandas as gpd
import numpy as np
from dotenv import load_dotenv
from geojson_pydantic import Feature
from geojson_pydantic import Feature, FeatureCollection
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from shapely.geometry import shape
from geo_assistant.agent.state import GeoAssistantState
# Load environment variables
load_dotenv()
@@ -94,3 +100,128 @@ async def get_place(
],
},
)
def normalize_place_type(place: str) -> str:
"""Normalize place type input to Overture categories."""
place_lower = place.lower().strip()
# Mapping of variations to canonical Overture types
mappings = {
# Restaurants
"restaurant": "restaurant",
"restaurants": "restaurant",
# Cafes/Coffee shops
"cafe": "cafe",
"cafes": "cafe",
"coffee": "cafe",
"coffee shop": "cafe",
"coffee shops": "cafe",
"coffeeshop": "cafe",
# Bars
"bar": "bar",
"bars": "bar",
"pub": "bar",
"pubs": "bar",
}
return mappings.get(place_lower, place_lower)
def _format_places_within_buffer_message(gdf: gpd.GeoDataFrame) -> str:
"""Format GeoDataFrame of places into a readable message."""
count = len(gdf)
if count == 0:
return "No places found matching your criteria."
places_list = []
for _, row in gdf.iterrows():
name = row.get("name", "Unknown")
websites = row.get("websites")
# Handle case where websites might not be present
website = (
websites[0]
if isinstance(websites, np.ndarray) and len(websites) > 0
else None
)
if website:
places_list.append(f"{name} - {website}")
else:
places_list.append(f"{name}")
formatted_places = "\n".join(places_list)
return f"Found {count} places:\n{formatted_places}"
@tool
async def get_places_within_buffer(
place: str,
state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
"""Get places from Overture Maps within user specified area and user specified Overture place type.
Accepts: restaurant(s), cafe(s), coffee shop(s), bar(s), pub(s) - case insensitive."""
# Normalize the place type
place = normalize_place_type(place)
# get bounds of buffered place
search_area = state["search_area"]
db_connection = create_database_connection()
source = os.getenv("OVERTURE_SOURCE", "local")
if source == "s3":
data_path = os.getenv("OVERTURE_S3_PATH")
db_connection.execute("SET s3_region='us-west-2';")
else:
data_path = os.getenv("OVERTURE_LOCAL_PATH")
places_df = db_connection.execute(
f"""
LOAD spatial;
SELECT
id,
names.primary AS name,
ST_AsGeoJSON(geometry) AS geometry,
websites,
socials,
categories
FROM read_parquet(
'{data_path}',
filename=true,
hive_partitioning=1
)
WHERE ST_Intersects(geometry, ST_GeomFromGeoJSON('{json.dumps(search_area.geometry.model_dump())}'))
AND categories.primary = '{place}'
LIMIT 10;
""",
).fetchdf()
db_connection.close()
# Convert geometry column from GeoJSON strings to shapely geometries
places_df["geometry"] = places_df["geometry"].apply(lambda x: shape(json.loads(x)))
# Create GeoDataFrame
gdf = gpd.GeoDataFrame(places_df, geometry="geometry", crs="EPSG:4326")
# Convert to GeoJSON FeatureCollection and ensure no numpy arrays
feature_collection = FeatureCollection.model_validate(
json.loads(json.dumps(gdf.__geo_interface__, default=str)),
)
return Command(
update={
"places_within_buffer": feature_collection,
"messages": [
ToolMessage(
content=_format_places_within_buffer_message(gdf),
tool_call_id=tool_call_id,
),
],
},
)
+55
View File
@@ -1,9 +1,14 @@
import os
import geopandas as gpd
import pytest
from geojson_pydantic import Feature, Point
from langchain_core.tools.base import ToolCall
from shapely.geometry import Point as ShapelyPoint
from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.tools.overture import get_place
from src.geo_assistant.tools.overture import get_places_within_buffer
@pytest.fixture(autouse=True)
@@ -18,6 +23,39 @@ def setup_ci_env():
yield
@pytest.fixture
def geo_assistant_with_buffer_fixture():
"""Fixture with a point at (-9.1393, 38.7223) and 0.5km buffer as search_area."""
place_geojson = Feature(
type="Feature",
geometry=Point(type="Point", coordinates=[-9.1393, 38.7223]),
properties={"name": "Neighbourhood Cafe Lisbon"},
)
gdf = gpd.GeoDataFrame(
[{"geometry": ShapelyPoint(-9.1393, 38.7223)}],
crs="EPSG:4326",
)
# Convert to Web Mercator for meter-based buffering
gdf_m = gdf.to_crs(epsg=3857)
gdf_m["geometry"] = gdf_m["geometry"].buffer(500) # 0.5km = 500m
# Convert back to WGS84
gdf_buffered = gdf_m.to_crs(epsg=4326)
# Get the buffered geometry as GeoJSON
search_area_geojson = Feature(
type="Feature",
geometry=gdf_buffered.iloc[0].geometry.__geo_interface__,
properties={},
)
return GeoAssistantState(
place=place_geojson,
search_area=search_area_geojson,
messages=[],
)
async def test_get_place():
command = await get_place.ainvoke(
ToolCall(
@@ -28,3 +66,20 @@ async def test_get_place():
),
)
assert "place" in command.update
async def test_get_places_within_buffer(geo_assistant_with_buffer_fixture):
command = await get_places_within_buffer.ainvoke(
ToolCall(
name="get_places_within_buffer",
type="tool_call",
id="test_id_places_within_buffer",
args={
"place": "cafe",
"state": geo_assistant_with_buffer_fixture,
"tool_call_id": "test_id_places_within_buffer",
},
),
)
assert "places_within_buffer" in command.update