diff --git a/src/geo_assistant/agent/graph.py b/src/geo_assistant/agent/graph.py index 1e50d9b..4ef8de8 100644 --- a/src/geo_assistant/agent/graph.py +++ b/src/geo_assistant/agent/graph.py @@ -8,6 +8,7 @@ from geo_assistant.agent.state import GeoAssistantState from geo_assistant.tools import ( fetch_naip_img, get_place, + get_places_within_buffer, get_search_area, summarize_sat_img, ) @@ -19,10 +20,11 @@ You have the following tools available to you. - get_place: Get a place from the Overture Maps database - get_search_area: Get a search area buffer in km around the place defined in the agent state +- get_places_within_buffer: Get places from the Overture Maps database within the search area defined in the agent state - summarize_sat_img: Summarize the contents of a satellite image using an LLM - fetch_naip_img: A NAIP imagery fetch tool. Use this to fetch NAIP aerial imagery for a given area of interest returned by the overture location lookup tool and date range (do your best to extract the date range from the user's query if provided, otherwise ask the user to specify a date range) -For places if you have links to social media, include them in the response. +Do not use background knowledge, only use the tools above to answer questions. The current date and time is {now}. """ @@ -35,6 +37,7 @@ async def create_graph(): tools=[ get_place, get_search_area, + get_places_within_buffer, fetch_naip_img, summarize_sat_img, ], diff --git a/src/geo_assistant/agent/state.py b/src/geo_assistant/agent/state.py index 8d0aca9..b727944 100644 --- a/src/geo_assistant/agent/state.py +++ b/src/geo_assistant/agent/state.py @@ -1,6 +1,6 @@ from typing import NotRequired -from geojson_pydantic import Feature +from geojson_pydantic import Feature, FeatureCollection from langchain.agents import AgentState from pydantic import Field @@ -8,6 +8,7 @@ from pydantic import Field class GeoAssistantState(AgentState): place: NotRequired[Feature | None] = None search_area: NotRequired[Feature | None] = None + places_within_buffer: NotRequired[FeatureCollection | None] = None naip_img_bytes: NotRequired[str | None] = Field( default=None, description="Base 64 encoded bytes str of the saved NAIP RGB JPEG image", diff --git a/src/geo_assistant/frontend/app.py b/src/geo_assistant/frontend/app.py index 8bc5c89..5f188cc 100644 --- a/src/geo_assistant/frontend/app.py +++ b/src/geo_assistant/frontend/app.py @@ -67,7 +67,11 @@ def stream_chat(user_message: str): # Check for GeoJSON features and render map if present geojson_features = {} for key, value in state.items(): - if value and isinstance(value, dict) and value.get("type") == "Feature": + if ( + value + and isinstance(value, dict) + and value.get("type") in ["Feature", "FeatureCollection"] + ): geojson_features[key] = value elif value and isinstance(value, str) and key == "naip_img_bytes": # Handle base64-encoded jpeg data @@ -107,13 +111,20 @@ def stream_chat(user_message: str): # Calculate center from all features all_lons, all_lats = [], [] - for feature in geojson_features.values(): - geom = feature.get("geometry", {}) - coords = get_coords_from_geometry(geom) - for coord in coords: - if len(coord) >= 2: - all_lons.append(coord[0]) - all_lats.append(coord[1]) + for feature_data in geojson_features.values(): + # Handle both Feature and FeatureCollection + if feature_data.get("type") == "FeatureCollection": + features = feature_data.get("features", []) + else: + features = [feature_data] + + for feature in features: + geom = feature.get("geometry", {}) + coords = get_coords_from_geometry(geom) + for coord in coords: + if len(coord) >= 2: + all_lons.append(coord[0]) + all_lats.append(coord[1]) if all_lons and all_lats: center_lat = sum(all_lats) / len(all_lats) @@ -124,7 +135,11 @@ def stream_chat(user_message: str): m = folium.Map(location=[center_lat, center_lon], zoom_start=10) # Add features to map with different colors - colors = {"place": "blue", "search_area": "red"} + colors = { + "place": "blue", + "search_area": "red", + "places_within_buffer": "green", + } def make_style_function(color): """Create a style function with the given color.""" @@ -135,13 +150,26 @@ def stream_chat(user_message: str): "fillOpacity": 0.3, } - for key, feature in geojson_features.items(): - color = colors.get(key, "green") - folium.GeoJson( - feature, - style_function=make_style_function(color), - tooltip=key, - ).add_to(m) + for key, feature_data in geojson_features.items(): + color = colors.get(key, "purple") + + # Handle both Feature and FeatureCollection + if feature_data.get("type") == "FeatureCollection": + # For FeatureCollections, add each feature with a popup showing its name + for feature in feature_data.get("features", []): + name = feature.get("properties", {}).get("name", key) + folium.GeoJson( + feature, + style_function=make_style_function(color), + tooltip=name, + ).add_to(m) + else: + # For single Features + folium.GeoJson( + feature_data, + style_function=make_style_function(color), + tooltip=key, + ).add_to(m) # Fit map to bounds if we have coordinates if all_lons and all_lats: diff --git a/src/geo_assistant/tools/__init__.py b/src/geo_assistant/tools/__init__.py index 72a57a1..c2233c1 100644 --- a/src/geo_assistant/tools/__init__.py +++ b/src/geo_assistant/tools/__init__.py @@ -1,6 +1,12 @@ from geo_assistant.tools.buffer import get_search_area from geo_assistant.tools.naip import fetch_naip_img -from geo_assistant.tools.overture import get_place +from geo_assistant.tools.overture import get_place, get_places_within_buffer from geo_assistant.tools.summarize import summarize_sat_img -__all__ = ["fetch_naip_img", "get_place", "get_search_area", "summarize_sat_img"] +__all__ = [ + "fetch_naip_img", + "get_place", + "get_places_within_buffer", + "get_search_area", + "summarize_sat_img", +] diff --git a/src/geo_assistant/tools/overture.py b/src/geo_assistant/tools/overture.py index 8183c4d..a77a673 100644 --- a/src/geo_assistant/tools/overture.py +++ b/src/geo_assistant/tools/overture.py @@ -3,12 +3,18 @@ import os from typing import Annotated import duckdb +import geopandas as gpd +import numpy as np from dotenv import load_dotenv -from geojson_pydantic import Feature +from geojson_pydantic import Feature, FeatureCollection from langchain_core.messages import ToolMessage from langchain_core.tools import tool from langchain_core.tools.base import InjectedToolCallId +from langgraph.prebuilt import InjectedState from langgraph.types import Command +from shapely.geometry import shape + +from geo_assistant.agent.state import GeoAssistantState # Load environment variables load_dotenv() @@ -94,3 +100,128 @@ async def get_place( ], }, ) + + +def normalize_place_type(place: str) -> str: + """Normalize place type input to Overture categories.""" + place_lower = place.lower().strip() + + # Mapping of variations to canonical Overture types + mappings = { + # Restaurants + "restaurant": "restaurant", + "restaurants": "restaurant", + # Cafes/Coffee shops + "cafe": "cafe", + "cafes": "cafe", + "coffee": "cafe", + "coffee shop": "cafe", + "coffee shops": "cafe", + "coffeeshop": "cafe", + # Bars + "bar": "bar", + "bars": "bar", + "pub": "bar", + "pubs": "bar", + } + + return mappings.get(place_lower, place_lower) + + +def _format_places_within_buffer_message(gdf: gpd.GeoDataFrame) -> str: + """Format GeoDataFrame of places into a readable message.""" + count = len(gdf) + + if count == 0: + return "No places found matching your criteria." + + places_list = [] + for _, row in gdf.iterrows(): + name = row.get("name", "Unknown") + websites = row.get("websites") + + # Handle case where websites might not be present + website = ( + websites[0] + if isinstance(websites, np.ndarray) and len(websites) > 0 + else None + ) + + if website: + places_list.append(f" • {name} - {website}") + else: + places_list.append(f" • {name}") + + formatted_places = "\n".join(places_list) + return f"Found {count} places:\n{formatted_places}" + + +@tool +async def get_places_within_buffer( + place: str, + state: Annotated[GeoAssistantState, InjectedState], + tool_call_id: Annotated[str, InjectedToolCallId], +) -> Command: + """Get places from Overture Maps within user specified area and user specified Overture place type. + + Accepts: restaurant(s), cafe(s), coffee shop(s), bar(s), pub(s) - case insensitive.""" + + # Normalize the place type + place = normalize_place_type(place) + + # get bounds of buffered place + search_area = state["search_area"] + + db_connection = create_database_connection() + source = os.getenv("OVERTURE_SOURCE", "local") + if source == "s3": + data_path = os.getenv("OVERTURE_S3_PATH") + db_connection.execute("SET s3_region='us-west-2';") + else: + data_path = os.getenv("OVERTURE_LOCAL_PATH") + + places_df = db_connection.execute( + f""" + LOAD spatial; + SELECT + id, + names.primary AS name, + ST_AsGeoJSON(geometry) AS geometry, + websites, + socials, + categories + FROM read_parquet( + '{data_path}', + filename=true, + hive_partitioning=1 + ) + WHERE ST_Intersects(geometry, ST_GeomFromGeoJSON('{json.dumps(search_area.geometry.model_dump())}')) + AND categories.primary = '{place}' + LIMIT 10; + """, + ).fetchdf() + + db_connection.close() + + # Convert geometry column from GeoJSON strings to shapely geometries + places_df["geometry"] = places_df["geometry"].apply(lambda x: shape(json.loads(x))) + + # Create GeoDataFrame + gdf = gpd.GeoDataFrame(places_df, geometry="geometry", crs="EPSG:4326") + + # Convert to GeoJSON FeatureCollection and ensure no numpy arrays + feature_collection = FeatureCollection.model_validate( + json.loads(json.dumps(gdf.__geo_interface__, default=str)), + ) + + return Command( + update={ + "places_within_buffer": feature_collection, + "messages": [ + ToolMessage( + content=_format_places_within_buffer_message(gdf), + tool_call_id=tool_call_id, + ), + ], + }, + ) diff --git a/tests/tools/test_overture.py b/tests/tools/test_overture.py index 264e5a8..6fae726 100644 --- a/tests/tools/test_overture.py +++ b/tests/tools/test_overture.py @@ -1,9 +1,14 @@ import os +import geopandas as gpd import pytest +from geojson_pydantic import Feature, Point from langchain_core.tools.base import ToolCall +from shapely.geometry import Point as ShapelyPoint +from geo_assistant.agent.state import GeoAssistantState from geo_assistant.tools.overture import get_place +from src.geo_assistant.tools.overture import get_places_within_buffer @pytest.fixture(autouse=True) @@ -18,6 +23,39 @@ def setup_ci_env(): yield +@pytest.fixture +def geo_assistant_with_buffer_fixture(): + """Fixture with a point at (-9.1393, 38.7223) and 0.5km buffer as search_area.""" + place_geojson = Feature( + type="Feature", + geometry=Point(type="Point", coordinates=[-9.1393, 38.7223]), + properties={"name": "Neighbourhood Cafe Lisbon"}, + ) + gdf = gpd.GeoDataFrame( + [{"geometry": ShapelyPoint(-9.1393, 38.7223)}], + crs="EPSG:4326", + ) + + # Convert to Web Mercator for meter-based buffering + gdf_m = gdf.to_crs(epsg=3857) + gdf_m["geometry"] = gdf_m["geometry"].buffer(500) # 0.5km = 500m + + # Convert back to WGS84 + gdf_buffered = gdf_m.to_crs(epsg=4326) + + # Get the buffered geometry as GeoJSON + search_area_geojson = Feature( + type="Feature", + geometry=gdf_buffered.iloc[0].geometry.__geo_interface__, + properties={}, + ) + return GeoAssistantState( + place=place_geojson, + search_area=search_area_geojson, + messages=[], + ) + + async def test_get_place(): command = await get_place.ainvoke( ToolCall( @@ -28,3 +66,20 @@ async def test_get_place(): ), ) assert "place" in command.update + + +async def test_get_places_within_buffer(geo_assistant_with_buffer_fixture): + command = await get_places_within_buffer.ainvoke( + ToolCall( + name="get_places_within_buffer", + type="tool_call", + id="test_id_places_within_buffer", + args={ + "place": "cafe", + "state": geo_assistant_with_buffer_fixture, + "tool_call_id": "test_id_places_within_buffer", + }, + ), + ) + + assert "places_within_buffer" in command.update