diff --git a/pyproject.toml b/pyproject.toml index 86a391f..921d4d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "geopandas>=1.1.1", "dspy>=3.0.4", "watchdog>=6.0.0", + "folium>=0.15.0", ] [dependency-groups] diff --git a/src/geo_assistant/agent/graph.py b/src/geo_assistant/agent/graph.py index 42df6f0..96d7fb4 100644 --- a/src/geo_assistant/agent/graph.py +++ b/src/geo_assistant/agent/graph.py @@ -7,13 +7,19 @@ from geo_assistant.agent.llms import llm from geo_assistant.tools.naip import fetch_naip_img from geo_assistant.tools.overture import get_place from geo_assistant.tools.buffer import get_search_area +from geo_assistant.tools.summarize import summarize_sat_img SYSTEM_PROMPT = """ You are a helpful assistant that can answer questions and help with tasks. -You have access to the following tools: -- Overture location lookup tool: use this to get geographic information about locations in the US based on the user's query -- NAIP imagery fetch tool: use this to fetch NAIP aerial imagery for a given area of interest returned by the overture location lookup tool and date range (do your best to extract the date range from the user's query if provided, otherwise ask the user to specify a date range) +You have the following tools available to you. + +- get_place: Get a place from the Overture Maps database +- get_search_area: Get a search area buffer in km around the place defined in the agent state +- summarize_sat_img: Summarize the contents of a satellite image using an LLM +- fetch_naip_img: A NAIP imagery fetch tool. Use this to fetch NAIP aerial imagery for a given area of interest returned by the overture location lookup tool and date range (do your best to extract the date range from the user's query if provided, otherwise ask the user to specify a date range) + +For places if you have links to social media, include them in the response. The current date and time is {now}. """ @@ -27,6 +33,7 @@ async def create_graph(): get_place, get_search_area, fetch_naip_img, + summarize_sat_img, ], system_prompt=SYSTEM_PROMPT.format( now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") diff --git a/src/geo_assistant/agent/state.py b/src/geo_assistant/agent/state.py index 7bc5b1d..40973ed 100644 --- a/src/geo_assistant/agent/state.py +++ b/src/geo_assistant/agent/state.py @@ -1,12 +1,12 @@ from langchain.agents import AgentState from geojson_pydantic import Feature -from typing import Optional +from typing_extensions import NotRequired from pydantic import Field class GeoAssistantState(AgentState): - place: Optional[Feature] - search_area: Optional[Feature] - naip_png_path: Optional[str] = Field( + place: NotRequired[Feature | None] = None + search_area: NotRequired[Feature | None] = None + naip_png_path: NotRequired[str | None] = Field( default=None, description="Path to the saved NAIP RGB PNG image" ) diff --git a/src/geo_assistant/api/app.py b/src/geo_assistant/api/app.py index 98f1e04..739d932 100644 --- a/src/geo_assistant/api/app.py +++ b/src/geo_assistant/api/app.py @@ -92,13 +92,8 @@ async def stream_chat( agent = next(iter(update.keys())) payload = update[agent] - if "place" not in payload: # TODO: why is this needed? - payload["place"] = None - if "search_area" not in payload: # TODO: why is this needed? - payload["search_area"] = None - state_payload = GeoAssistantState(**payload) - - resp = ChatResponse(thread_id=str(thread_id), state=state_payload) + state = GeoAssistantState(**payload) + resp = ChatResponse(thread_id=str(thread_id), state=state) line = json.dumps(resp.model_dump()) + "\n" yield line.encode("utf-8") diff --git a/src/geo_assistant/api/schemas/chat.py b/src/geo_assistant/api/schemas/chat.py index ade45c0..f757ce1 100644 --- a/src/geo_assistant/api/schemas/chat.py +++ b/src/geo_assistant/api/schemas/chat.py @@ -3,8 +3,8 @@ from geo_assistant.agent.state import GeoAssistantState class ChatRequestBody(BaseModel): - agent_state_input: GeoAssistantState thread_id: str + agent_state_input: GeoAssistantState class ChatResponse(BaseModel): diff --git a/src/geo_assistant/frontend/app.py b/src/geo_assistant/frontend/app.py index 2566dd0..b35403e 100644 --- a/src/geo_assistant/frontend/app.py +++ b/src/geo_assistant/frontend/app.py @@ -4,6 +4,8 @@ import uuid import httpx import streamlit as st +import streamlit.components.v1 as components +import folium from dotenv import load_dotenv # Load environment variables from .env file @@ -14,6 +16,8 @@ API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000") st.set_page_config(page_title="Geo Assistant", page_icon="💬") +st.title("Geo Assistant") + # Initialize session state if "thread_id" not in st.session_state: st.session_state.thread_id = str(uuid.uuid4()) @@ -44,58 +48,106 @@ def stream_chat(user_message: str): response.raise_for_status() for line in response.iter_lines(): - print("=" * 100) - print(line) - print("=" * 100) - if not line: continue data = json.loads(line) - print("=" * 100) - print(data) - print("=" * 100) state = data.get("state", {}) - messages = state.get("messages", []) - - if not messages: - continue + messages = state.pop("messages", []) for msg in messages: msg_type = msg.get("type", "") content = msg.get("content", "") + if not content: + continue + with st.chat_message(msg_type): + st.markdown(content) - yield msg_type, content + # Check for GeoJSON features and render map if present + geojson_features = {} + for key, value in state.items(): + if value and isinstance(value, dict) and value.get("type") == "Feature": + geojson_features[key] = value + # with st.chat_message("tool"): + # st.code(json.dumps(value, indent=2), language="json") + elif value: + with st.chat_message("tool"): + st.code(json.dumps(value, indent=2), language="json") + + # Render map if GeoJSON features are present + if geojson_features: + # Helper function to extract coordinates from geometry + def get_coords_from_geometry(geom): + """Extract all coordinates from a GeoJSON geometry.""" + geom_type = geom.get("type", "") + coords = geom.get("coordinates", []) + + if geom_type == "Point": + return [coords] + elif geom_type == "LineString": + return coords + elif geom_type == "Polygon": + return coords[0] if coords else [] + elif geom_type == "MultiPoint": + return coords + elif geom_type == "MultiLineString": + return [c for line in coords for c in line] + elif geom_type == "MultiPolygon": + return [c for poly in coords for c in poly[0]] if coords else [] + return [] + + # Calculate center from all features + all_lons, all_lats = [], [] + for feature in geojson_features.values(): + geom = feature.get("geometry", {}) + coords = get_coords_from_geometry(geom) + for coord in coords: + if len(coord) >= 2: + all_lons.append(coord[0]) + all_lats.append(coord[1]) + + if all_lons and all_lats: + center_lat = sum(all_lats) / len(all_lats) + center_lon = sum(all_lons) / len(all_lons) + else: + center_lat, center_lon = 0.0, 0.0 + + m = folium.Map(location=[center_lat, center_lon], zoom_start=10) + + # Add features to map with different colors + colors = {"place": "blue", "search_area": "red"} + + def make_style_function(color): + """Create a style function with the given color.""" + return lambda x: { + "fillColor": color, + "color": color, + "weight": 2, + "fillOpacity": 0.3, + } + + for key, feature in geojson_features.items(): + color = colors.get(key, "green") + folium.GeoJson( + feature, + style_function=make_style_function(color), + tooltip=key, + ).add_to(m) + + # Fit map to bounds if we have coordinates + if all_lons and all_lats: + m.fit_bounds( + [[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]] + ) + + # Display the map + with st.chat_message("tool"): + st.markdown("**Map View**") + map_html = m._repr_html_() + components.html(map_html, height=400) -# Main UI -st.title("Geo Assistant") - -# Display chat history -for item in st.session_state.chat_history: - role = item["role"] - content = item["content"] - - with st.chat_message(role): - if role == "assistant": - # For assistant messages, check if it's a tool message - if item.get("is_tool"): - st.code(content, language="json") - else: - st.markdown(content) - else: - st.markdown(content) - -# Chat input if prompt := st.chat_input("Type your message..."): - st.session_state.chat_history.append({"role": "user", "content": prompt}) - - for msg_type, content in stream_chat(prompt): - if msg_type == "tool": - st.session_state.chat_history.append({"role": "tool", "content": content}) - elif msg_type in ["ai", "assistant"]: - st.session_state.chat_history.append( - {"role": "assistant", "content": content} - ) - - st.rerun() + with st.chat_message("user"): + st.markdown(prompt) + stream_chat(prompt) diff --git a/src/geo_assistant/tools/buffer.py b/src/geo_assistant/tools/buffer.py index 92a94e2..b72d813 100644 --- a/src/geo_assistant/tools/buffer.py +++ b/src/geo_assistant/tools/buffer.py @@ -6,7 +6,6 @@ from langchain_core.messages import ToolMessage from langchain_core.tools import tool from typing import Annotated from geo_assistant.agent.state import GeoAssistantState - from geojson_pydantic import Feature @@ -18,7 +17,7 @@ async def get_search_area( ) -> Command: """Get a search area buffer in km around the place defined in the agent state.""" - place_feature = state["place"] + place_feature = state.get("place") if not place_feature: return Command( @@ -52,7 +51,7 @@ async def get_search_area( buffer_feature = Feature( type="Feature", geometry=gdf.iloc[0].geometry.__geo_interface__, - properties={}, + properties=place_feature.properties.copy(), ) return Command( diff --git a/src/geo_assistant/tools/overture.py b/src/geo_assistant/tools/overture.py index c9b82f3..8d2c1ef 100644 --- a/src/geo_assistant/tools/overture.py +++ b/src/geo_assistant/tools/overture.py @@ -8,6 +8,7 @@ from langchain_core.messages import ToolMessage from langchain_core.tools import tool from langchain_core.tools.base import InjectedToolCallId from langgraph.types import Command +from geojson_pydantic import Feature # Load environment variables load_dotenv() @@ -71,27 +72,22 @@ async def get_place( geometry = json.loads(location_results[0][-1]) - # Create FeatureCollection - feature_collection = { - "type": "FeatureCollection", - "features": [ - { - "type": "Feature", - "geometry": geometry, - "properties": { - "name": location_results[0][2], - "overture_id": location_results[0][0], - }, - } - ], - } + feature = Feature( + type="Feature", + geometry=geometry, + properties={ + "overture_id": location_results[0][0], + "name": location_results[0][2], + "socials": location_results[0][4], + }, + ) return Command( update={ - "place": feature_collection, + "place": feature, "messages": [ ToolMessage( - content=f"Found place with Overture name: {location_results[0][2]} based on user query", + content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}", tool_call_id=tool_call_id, ) ], diff --git a/src/geo_assistant/tools/summarize.py b/src/geo_assistant/tools/summarize.py index faa415c..e0b8fe7 100644 --- a/src/geo_assistant/tools/summarize.py +++ b/src/geo_assistant/tools/summarize.py @@ -66,7 +66,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent() @tool -def summarize_sat_img( +async def summarize_sat_img( img_url: str, tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None, ) -> Command: diff --git a/tests/test_api.py b/tests/test_api.py index aba692e..d4dc955 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -29,7 +29,12 @@ async def test_hello_world(initialized_app): "/chat", json={ "agent_state_input": { - "messages": [{"content": "Hello, world!", "type": "human"}], + "messages": [ + { + "content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it", + "type": "human", + } + ], "place": None, "search_area": None, }, diff --git a/uv.lock b/uv.lock index 25b8359..7c5b34f 100644 --- a/uv.lock +++ b/uv.lock @@ -325,6 +325,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" }, ] +[[package]] +name = "branca" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/32/14/9d409124bda3f4ab7af3802aba07181d1fd56aa96cc4b999faea6a27a0d2/branca-0.8.2.tar.gz", hash = "sha256:e5040f4c286e973658c27de9225c1a5a7356dd0702a7c8d84c0f0dfbde388fe7", size = 27890, upload-time = "2025-10-06T10:28:20.305Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/50/fc9680058e63161f2f63165b84c957a0df1415431104c408e8104a3a18ef/branca-0.8.2-py3-none-any.whl", hash = "sha256:2ebaef3983e3312733c1ae2b793b0a8ba3e1c4edeb7598e10328505280cf2f7c", size = 26193, upload-time = "2025-10-06T10:28:19.255Z" }, +] + [[package]] name = "cachetools" version = "6.2.2" @@ -788,6 +800,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, ] +[[package]] +name = "folium" +version = "0.20.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "branca" }, + { name = "jinja2" }, + { name = "numpy" }, + { name = "requests" }, + { name = "xyzservices" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/76/84a1b1b00ce71f9c0c44af7d80f310c02e2e583591fe7d4cb03baecd0d3f/folium-0.20.0.tar.gz", hash = "sha256:a0d78b9d5a36ba7589ca9aedbd433e84e9fcab79cd6ac213adbcff922e454cb9", size = 109932, upload-time = "2025-06-16T20:22:51.803Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/a8/5f764f333204db0390362a4356d03a43626997f26818a0e9396f1b3bd8c9/folium-0.20.0-py2.py3-none-any.whl", hash = "sha256:f0bc2a92acde20bca56367aa5c1c376c433f450608d058daebab2fc9bf8198bf", size = 113394, upload-time = "2025-06-16T20:22:50.318Z" }, +] + [[package]] name = "fonttools" version = "4.61.0" @@ -920,6 +948,7 @@ dependencies = [ { name = "dspy" }, { name = "duckdb" }, { name = "fastapi" }, + { name = "folium" }, { name = "geojson-pydantic" }, { name = "geopandas" }, { name = "httpx" }, @@ -953,6 +982,7 @@ requires-dist = [ { name = "dspy", specifier = ">=3.0.4" }, { name = "duckdb" }, { name = "fastapi" }, + { name = "folium", specifier = ">=0.15.0" }, { name = "geojson-pydantic" }, { name = "geopandas", specifier = ">=1.1.1" }, { name = "httpx" }, @@ -4141,6 +4171,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/c9/7243eb3f9eaabd1a88a5a5acadf06df2d83b100c62684b7425c6a11bcaa8/xxhash-3.6.0-cp314-cp314t-win_arm64.whl", hash = "sha256:bb79b1e63f6fd84ec778a4b1916dfe0a7c3fdb986c06addd5db3a0d413819d95", size = 28898, upload-time = "2025-10-02T14:36:17.843Z" }, ] +[[package]] +name = "xyzservices" +version = "2025.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/0f/022795fc1201e7c29e742a509913badb53ce0b38f64b6db859e2f6339da9/xyzservices-2025.11.0.tar.gz", hash = "sha256:2fc72b49502b25023fd71e8f532fb4beddbbf0aa124d90ea25dba44f545e17ce", size = 1135703, upload-time = "2025-11-22T11:31:51.82Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/5c/2c189d18d495dd0fa3f27ccc60762bbc787eed95b9b0147266e72bb76585/xyzservices-2025.11.0-py3-none-any.whl", hash = "sha256:de66a7599a8d6dad63980b77defd1d8f5a5a9cb5fc8774ea1c6e89ca7c2a3d2f", size = 93916, upload-time = "2025-11-22T11:31:50.525Z" }, +] + [[package]] name = "yarl" version = "1.22.0"