Fix frontend (#15)

* Fix frontend

* Fix feature creation from overture tool
This commit is contained in:
Daniel Wiesmann
2025-12-05 10:04:38 +00:00
committed by GitHub
parent 8f0239c1c9
commit 7c97b475e4
11 changed files with 172 additions and 78 deletions
+1
View File
@@ -25,6 +25,7 @@ dependencies = [
"geopandas>=1.1.1", "geopandas>=1.1.1",
"dspy>=3.0.4", "dspy>=3.0.4",
"watchdog>=6.0.0", "watchdog>=6.0.0",
"folium>=0.15.0",
] ]
[dependency-groups] [dependency-groups]
+10 -3
View File
@@ -7,13 +7,19 @@ from geo_assistant.agent.llms import llm
from geo_assistant.tools.naip import fetch_naip_img from geo_assistant.tools.naip import fetch_naip_img
from geo_assistant.tools.overture import get_place from geo_assistant.tools.overture import get_place
from geo_assistant.tools.buffer import get_search_area from geo_assistant.tools.buffer import get_search_area
from geo_assistant.tools.summarize import summarize_sat_img
SYSTEM_PROMPT = """ SYSTEM_PROMPT = """
You are a helpful assistant that can answer questions and help with tasks. You are a helpful assistant that can answer questions and help with tasks.
You have access to the following tools: You have the following tools available to you.
- Overture location lookup tool: use this to get geographic information about locations in the US based on the user's query
- NAIP imagery fetch tool: use this to fetch NAIP aerial imagery for a given area of interest returned by the overture location lookup tool and date range (do your best to extract the date range from the user's query if provided, otherwise ask the user to specify a date range) - get_place: Get a place from the Overture Maps database
- get_search_area: Get a search area buffer in km around the place defined in the agent state
- summarize_sat_img: Summarize the contents of a satellite image using an LLM
- fetch_naip_img: A NAIP imagery fetch tool. Use this to fetch NAIP aerial imagery for a given area of interest returned by the overture location lookup tool and date range (do your best to extract the date range from the user's query if provided, otherwise ask the user to specify a date range)
For places if you have links to social media, include them in the response.
The current date and time is {now}. The current date and time is {now}.
""" """
@@ -27,6 +33,7 @@ async def create_graph():
get_place, get_place,
get_search_area, get_search_area,
fetch_naip_img, fetch_naip_img,
summarize_sat_img,
], ],
system_prompt=SYSTEM_PROMPT.format( system_prompt=SYSTEM_PROMPT.format(
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+4 -4
View File
@@ -1,12 +1,12 @@
from langchain.agents import AgentState from langchain.agents import AgentState
from geojson_pydantic import Feature from geojson_pydantic import Feature
from typing import Optional from typing_extensions import NotRequired
from pydantic import Field from pydantic import Field
class GeoAssistantState(AgentState): class GeoAssistantState(AgentState):
place: Optional[Feature] place: NotRequired[Feature | None] = None
search_area: Optional[Feature] search_area: NotRequired[Feature | None] = None
naip_png_path: Optional[str] = Field( naip_png_path: NotRequired[str | None] = Field(
default=None, description="Path to the saved NAIP RGB PNG image" default=None, description="Path to the saved NAIP RGB PNG image"
) )
+2 -7
View File
@@ -92,13 +92,8 @@ async def stream_chat(
agent = next(iter(update.keys())) agent = next(iter(update.keys()))
payload = update[agent] payload = update[agent]
if "place" not in payload: # TODO: why is this needed? state = GeoAssistantState(**payload)
payload["place"] = None resp = ChatResponse(thread_id=str(thread_id), state=state)
if "search_area" not in payload: # TODO: why is this needed?
payload["search_area"] = None
state_payload = GeoAssistantState(**payload)
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
line = json.dumps(resp.model_dump()) + "\n" line = json.dumps(resp.model_dump()) + "\n"
yield line.encode("utf-8") yield line.encode("utf-8")
+1 -1
View File
@@ -3,8 +3,8 @@ from geo_assistant.agent.state import GeoAssistantState
class ChatRequestBody(BaseModel): class ChatRequestBody(BaseModel):
agent_state_input: GeoAssistantState
thread_id: str thread_id: str
agent_state_input: GeoAssistantState
class ChatResponse(BaseModel): class ChatResponse(BaseModel):
+94 -42
View File
@@ -4,6 +4,8 @@ import uuid
import httpx import httpx
import streamlit as st import streamlit as st
import streamlit.components.v1 as components
import folium
from dotenv import load_dotenv from dotenv import load_dotenv
# Load environment variables from .env file # Load environment variables from .env file
@@ -14,6 +16,8 @@ API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")
st.set_page_config(page_title="Geo Assistant", page_icon="💬") st.set_page_config(page_title="Geo Assistant", page_icon="💬")
st.title("Geo Assistant")
# Initialize session state # Initialize session state
if "thread_id" not in st.session_state: if "thread_id" not in st.session_state:
st.session_state.thread_id = str(uuid.uuid4()) st.session_state.thread_id = str(uuid.uuid4())
@@ -44,58 +48,106 @@ def stream_chat(user_message: str):
response.raise_for_status() response.raise_for_status()
for line in response.iter_lines(): for line in response.iter_lines():
print("=" * 100)
print(line)
print("=" * 100)
if not line: if not line:
continue continue
data = json.loads(line) data = json.loads(line)
print("=" * 100)
print(data)
print("=" * 100)
state = data.get("state", {}) state = data.get("state", {})
messages = state.get("messages", []) messages = state.pop("messages", [])
if not messages:
continue
for msg in messages: for msg in messages:
msg_type = msg.get("type", "") msg_type = msg.get("type", "")
content = msg.get("content", "") content = msg.get("content", "")
if not content:
continue
with st.chat_message(msg_type):
st.markdown(content)
yield msg_type, content # Check for GeoJSON features and render map if present
geojson_features = {}
for key, value in state.items():
if value and isinstance(value, dict) and value.get("type") == "Feature":
geojson_features[key] = value
# with st.chat_message("tool"):
# st.code(json.dumps(value, indent=2), language="json")
elif value:
with st.chat_message("tool"):
st.code(json.dumps(value, indent=2), language="json")
# Render map if GeoJSON features are present
if geojson_features:
# Helper function to extract coordinates from geometry
def get_coords_from_geometry(geom):
"""Extract all coordinates from a GeoJSON geometry."""
geom_type = geom.get("type", "")
coords = geom.get("coordinates", [])
if geom_type == "Point":
return [coords]
elif geom_type == "LineString":
return coords
elif geom_type == "Polygon":
return coords[0] if coords else []
elif geom_type == "MultiPoint":
return coords
elif geom_type == "MultiLineString":
return [c for line in coords for c in line]
elif geom_type == "MultiPolygon":
return [c for poly in coords for c in poly[0]] if coords else []
return []
# Calculate center from all features
all_lons, all_lats = [], []
for feature in geojson_features.values():
geom = feature.get("geometry", {})
coords = get_coords_from_geometry(geom)
for coord in coords:
if len(coord) >= 2:
all_lons.append(coord[0])
all_lats.append(coord[1])
if all_lons and all_lats:
center_lat = sum(all_lats) / len(all_lats)
center_lon = sum(all_lons) / len(all_lons)
else:
center_lat, center_lon = 0.0, 0.0
m = folium.Map(location=[center_lat, center_lon], zoom_start=10)
# Add features to map with different colors
colors = {"place": "blue", "search_area": "red"}
def make_style_function(color):
"""Create a style function with the given color."""
return lambda x: {
"fillColor": color,
"color": color,
"weight": 2,
"fillOpacity": 0.3,
}
for key, feature in geojson_features.items():
color = colors.get(key, "green")
folium.GeoJson(
feature,
style_function=make_style_function(color),
tooltip=key,
).add_to(m)
# Fit map to bounds if we have coordinates
if all_lons and all_lats:
m.fit_bounds(
[[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]]
)
# Display the map
with st.chat_message("tool"):
st.markdown("**Map View**")
map_html = m._repr_html_()
components.html(map_html, height=400)
# Main UI
st.title("Geo Assistant")
# Display chat history
for item in st.session_state.chat_history:
role = item["role"]
content = item["content"]
with st.chat_message(role):
if role == "assistant":
# For assistant messages, check if it's a tool message
if item.get("is_tool"):
st.code(content, language="json")
else:
st.markdown(content)
else:
st.markdown(content)
# Chat input
if prompt := st.chat_input("Type your message..."): if prompt := st.chat_input("Type your message..."):
st.session_state.chat_history.append({"role": "user", "content": prompt}) with st.chat_message("user"):
st.markdown(prompt)
for msg_type, content in stream_chat(prompt): stream_chat(prompt)
if msg_type == "tool":
st.session_state.chat_history.append({"role": "tool", "content": content})
elif msg_type in ["ai", "assistant"]:
st.session_state.chat_history.append(
{"role": "assistant", "content": content}
)
st.rerun()
+2 -3
View File
@@ -6,7 +6,6 @@ from langchain_core.messages import ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from typing import Annotated from typing import Annotated
from geo_assistant.agent.state import GeoAssistantState from geo_assistant.agent.state import GeoAssistantState
from geojson_pydantic import Feature from geojson_pydantic import Feature
@@ -18,7 +17,7 @@ async def get_search_area(
) -> Command: ) -> Command:
"""Get a search area buffer in km around the place defined in the agent state.""" """Get a search area buffer in km around the place defined in the agent state."""
place_feature = state["place"] place_feature = state.get("place")
if not place_feature: if not place_feature:
return Command( return Command(
@@ -52,7 +51,7 @@ async def get_search_area(
buffer_feature = Feature( buffer_feature = Feature(
type="Feature", type="Feature",
geometry=gdf.iloc[0].geometry.__geo_interface__, geometry=gdf.iloc[0].geometry.__geo_interface__,
properties={}, properties=place_feature.properties.copy(),
) )
return Command( return Command(
+12 -16
View File
@@ -8,6 +8,7 @@ from langchain_core.messages import ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId from langchain_core.tools.base import InjectedToolCallId
from langgraph.types import Command from langgraph.types import Command
from geojson_pydantic import Feature
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
@@ -71,27 +72,22 @@ async def get_place(
geometry = json.loads(location_results[0][-1]) geometry = json.loads(location_results[0][-1])
# Create FeatureCollection feature = Feature(
feature_collection = { type="Feature",
"type": "FeatureCollection", geometry=geometry,
"features": [ properties={
{ "overture_id": location_results[0][0],
"type": "Feature", "name": location_results[0][2],
"geometry": geometry, "socials": location_results[0][4],
"properties": { },
"name": location_results[0][2], )
"overture_id": location_results[0][0],
},
}
],
}
return Command( return Command(
update={ update={
"place": feature_collection, "place": feature,
"messages": [ "messages": [
ToolMessage( ToolMessage(
content=f"Found place with Overture name: {location_results[0][2]} based on user query", content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}",
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
) )
], ],
+1 -1
View File
@@ -66,7 +66,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
@tool @tool
def summarize_sat_img( async def summarize_sat_img(
img_url: str, img_url: str,
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None, tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
) -> Command: ) -> Command:
+6 -1
View File
@@ -29,7 +29,12 @@ async def test_hello_world(initialized_app):
"/chat", "/chat",
json={ json={
"agent_state_input": { "agent_state_input": {
"messages": [{"content": "Hello, world!", "type": "human"}], "messages": [
{
"content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it",
"type": "human",
}
],
"place": None, "place": None,
"search_area": None, "search_area": None,
}, },
Generated
+39
View File
@@ -325,6 +325,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" }, { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" },
] ]
[[package]]
name = "branca"
version = "0.8.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2" },
]
sdist = { url = "https://files.pythonhosted.org/packages/32/14/9d409124bda3f4ab7af3802aba07181d1fd56aa96cc4b999faea6a27a0d2/branca-0.8.2.tar.gz", hash = "sha256:e5040f4c286e973658c27de9225c1a5a7356dd0702a7c8d84c0f0dfbde388fe7", size = 27890, upload-time = "2025-10-06T10:28:20.305Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/50/fc9680058e63161f2f63165b84c957a0df1415431104c408e8104a3a18ef/branca-0.8.2-py3-none-any.whl", hash = "sha256:2ebaef3983e3312733c1ae2b793b0a8ba3e1c4edeb7598e10328505280cf2f7c", size = 26193, upload-time = "2025-10-06T10:28:19.255Z" },
]
[[package]] [[package]]
name = "cachetools" name = "cachetools"
version = "6.2.2" version = "6.2.2"
@@ -788,6 +800,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, { url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" },
] ]
[[package]]
name = "folium"
version = "0.20.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "branca" },
{ name = "jinja2" },
{ name = "numpy" },
{ name = "requests" },
{ name = "xyzservices" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c7/76/84a1b1b00ce71f9c0c44af7d80f310c02e2e583591fe7d4cb03baecd0d3f/folium-0.20.0.tar.gz", hash = "sha256:a0d78b9d5a36ba7589ca9aedbd433e84e9fcab79cd6ac213adbcff922e454cb9", size = 109932, upload-time = "2025-06-16T20:22:51.803Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/a8/5f764f333204db0390362a4356d03a43626997f26818a0e9396f1b3bd8c9/folium-0.20.0-py2.py3-none-any.whl", hash = "sha256:f0bc2a92acde20bca56367aa5c1c376c433f450608d058daebab2fc9bf8198bf", size = 113394, upload-time = "2025-06-16T20:22:50.318Z" },
]
[[package]] [[package]]
name = "fonttools" name = "fonttools"
version = "4.61.0" version = "4.61.0"
@@ -920,6 +948,7 @@ dependencies = [
{ name = "dspy" }, { name = "dspy" },
{ name = "duckdb" }, { name = "duckdb" },
{ name = "fastapi" }, { name = "fastapi" },
{ name = "folium" },
{ name = "geojson-pydantic" }, { name = "geojson-pydantic" },
{ name = "geopandas" }, { name = "geopandas" },
{ name = "httpx" }, { name = "httpx" },
@@ -953,6 +982,7 @@ requires-dist = [
{ name = "dspy", specifier = ">=3.0.4" }, { name = "dspy", specifier = ">=3.0.4" },
{ name = "duckdb" }, { name = "duckdb" },
{ name = "fastapi" }, { name = "fastapi" },
{ name = "folium", specifier = ">=0.15.0" },
{ name = "geojson-pydantic" }, { name = "geojson-pydantic" },
{ name = "geopandas", specifier = ">=1.1.1" }, { name = "geopandas", specifier = ">=1.1.1" },
{ name = "httpx" }, { name = "httpx" },
@@ -4141,6 +4171,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0f/c9/7243eb3f9eaabd1a88a5a5acadf06df2d83b100c62684b7425c6a11bcaa8/xxhash-3.6.0-cp314-cp314t-win_arm64.whl", hash = "sha256:bb79b1e63f6fd84ec778a4b1916dfe0a7c3fdb986c06addd5db3a0d413819d95", size = 28898, upload-time = "2025-10-02T14:36:17.843Z" }, { url = "https://files.pythonhosted.org/packages/0f/c9/7243eb3f9eaabd1a88a5a5acadf06df2d83b100c62684b7425c6a11bcaa8/xxhash-3.6.0-cp314-cp314t-win_arm64.whl", hash = "sha256:bb79b1e63f6fd84ec778a4b1916dfe0a7c3fdb986c06addd5db3a0d413819d95", size = 28898, upload-time = "2025-10-02T14:36:17.843Z" },
] ]
[[package]]
name = "xyzservices"
version = "2025.11.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ee/0f/022795fc1201e7c29e742a509913badb53ce0b38f64b6db859e2f6339da9/xyzservices-2025.11.0.tar.gz", hash = "sha256:2fc72b49502b25023fd71e8f532fb4beddbbf0aa124d90ea25dba44f545e17ce", size = 1135703, upload-time = "2025-11-22T11:31:51.82Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ef/5c/2c189d18d495dd0fa3f27ccc60762bbc787eed95b9b0147266e72bb76585/xyzservices-2025.11.0-py3-none-any.whl", hash = "sha256:de66a7599a8d6dad63980b77defd1d8f5a5a9cb5fc8774ea1c6e89ca7c2a3d2f", size = 93916, upload-time = "2025-11-22T11:31:50.525Z" },
]
[[package]] [[package]]
name = "yarl" name = "yarl"
version = "1.22.0" version = "1.22.0"