Fix frontend (#15)

* Fix frontend

* Fix feature creation from overture tool
This commit is contained in:
Daniel Wiesmann
2025-12-05 10:04:38 +00:00
committed by GitHub
parent 8f0239c1c9
commit 7c97b475e4
11 changed files with 172 additions and 78 deletions
+1
View File
@@ -25,6 +25,7 @@ dependencies = [
"geopandas>=1.1.1",
"dspy>=3.0.4",
"watchdog>=6.0.0",
"folium>=0.15.0",
]
[dependency-groups]
+10 -3
View File
@@ -7,13 +7,19 @@ from geo_assistant.agent.llms import llm
from geo_assistant.tools.naip import fetch_naip_img
from geo_assistant.tools.overture import get_place
from geo_assistant.tools.buffer import get_search_area
from geo_assistant.tools.summarize import summarize_sat_img
SYSTEM_PROMPT = """
You are a helpful assistant that can answer questions and help with tasks.
You have access to the following tools:
- Overture location lookup tool: use this to get geographic information about locations in the US based on the user's query
- NAIP imagery fetch tool: use this to fetch NAIP aerial imagery for a given area of interest returned by the overture location lookup tool and date range (do your best to extract the date range from the user's query if provided, otherwise ask the user to specify a date range)
You have the following tools available to you.
- get_place: Get a place from the Overture Maps database
- get_search_area: Get a search area buffer in km around the place defined in the agent state
- summarize_sat_img: Summarize the contents of a satellite image using an LLM
- fetch_naip_img: A NAIP imagery fetch tool. Use this to fetch NAIP aerial imagery for a given area of interest returned by the overture location lookup tool and date range (do your best to extract the date range from the user's query if provided, otherwise ask the user to specify a date range)
For places if you have links to social media, include them in the response.
The current date and time is {now}.
"""
@@ -27,6 +33,7 @@ async def create_graph():
get_place,
get_search_area,
fetch_naip_img,
summarize_sat_img,
],
system_prompt=SYSTEM_PROMPT.format(
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+4 -4
View File
@@ -1,12 +1,12 @@
from langchain.agents import AgentState
from geojson_pydantic import Feature
from typing import Optional
from typing_extensions import NotRequired
from pydantic import Field
class GeoAssistantState(AgentState):
place: Optional[Feature]
search_area: Optional[Feature]
naip_png_path: Optional[str] = Field(
place: NotRequired[Feature | None] = None
search_area: NotRequired[Feature | None] = None
naip_png_path: NotRequired[str | None] = Field(
default=None, description="Path to the saved NAIP RGB PNG image"
)
+2 -7
View File
@@ -92,13 +92,8 @@ async def stream_chat(
agent = next(iter(update.keys()))
payload = update[agent]
if "place" not in payload: # TODO: why is this needed?
payload["place"] = None
if "search_area" not in payload: # TODO: why is this needed?
payload["search_area"] = None
state_payload = GeoAssistantState(**payload)
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
state = GeoAssistantState(**payload)
resp = ChatResponse(thread_id=str(thread_id), state=state)
line = json.dumps(resp.model_dump()) + "\n"
yield line.encode("utf-8")
+1 -1
View File
@@ -3,8 +3,8 @@ from geo_assistant.agent.state import GeoAssistantState
class ChatRequestBody(BaseModel):
agent_state_input: GeoAssistantState
thread_id: str
agent_state_input: GeoAssistantState
class ChatResponse(BaseModel):
+94 -42
View File
@@ -4,6 +4,8 @@ import uuid
import httpx
import streamlit as st
import streamlit.components.v1 as components
import folium
from dotenv import load_dotenv
# Load environment variables from .env file
@@ -14,6 +16,8 @@ API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")
st.set_page_config(page_title="Geo Assistant", page_icon="💬")
st.title("Geo Assistant")
# Initialize session state
if "thread_id" not in st.session_state:
st.session_state.thread_id = str(uuid.uuid4())
@@ -44,58 +48,106 @@ def stream_chat(user_message: str):
response.raise_for_status()
for line in response.iter_lines():
print("=" * 100)
print(line)
print("=" * 100)
if not line:
continue
data = json.loads(line)
print("=" * 100)
print(data)
print("=" * 100)
state = data.get("state", {})
messages = state.get("messages", [])
if not messages:
continue
messages = state.pop("messages", [])
for msg in messages:
msg_type = msg.get("type", "")
content = msg.get("content", "")
if not content:
continue
with st.chat_message(msg_type):
st.markdown(content)
yield msg_type, content
# Check for GeoJSON features and render map if present
geojson_features = {}
for key, value in state.items():
if value and isinstance(value, dict) and value.get("type") == "Feature":
geojson_features[key] = value
# with st.chat_message("tool"):
# st.code(json.dumps(value, indent=2), language="json")
elif value:
with st.chat_message("tool"):
st.code(json.dumps(value, indent=2), language="json")
# Render map if GeoJSON features are present
if geojson_features:
# Helper function to extract coordinates from geometry
def get_coords_from_geometry(geom):
"""Extract all coordinates from a GeoJSON geometry."""
geom_type = geom.get("type", "")
coords = geom.get("coordinates", [])
if geom_type == "Point":
return [coords]
elif geom_type == "LineString":
return coords
elif geom_type == "Polygon":
return coords[0] if coords else []
elif geom_type == "MultiPoint":
return coords
elif geom_type == "MultiLineString":
return [c for line in coords for c in line]
elif geom_type == "MultiPolygon":
return [c for poly in coords for c in poly[0]] if coords else []
return []
# Calculate center from all features
all_lons, all_lats = [], []
for feature in geojson_features.values():
geom = feature.get("geometry", {})
coords = get_coords_from_geometry(geom)
for coord in coords:
if len(coord) >= 2:
all_lons.append(coord[0])
all_lats.append(coord[1])
if all_lons and all_lats:
center_lat = sum(all_lats) / len(all_lats)
center_lon = sum(all_lons) / len(all_lons)
else:
center_lat, center_lon = 0.0, 0.0
m = folium.Map(location=[center_lat, center_lon], zoom_start=10)
# Add features to map with different colors
colors = {"place": "blue", "search_area": "red"}
def make_style_function(color):
"""Create a style function with the given color."""
return lambda x: {
"fillColor": color,
"color": color,
"weight": 2,
"fillOpacity": 0.3,
}
for key, feature in geojson_features.items():
color = colors.get(key, "green")
folium.GeoJson(
feature,
style_function=make_style_function(color),
tooltip=key,
).add_to(m)
# Fit map to bounds if we have coordinates
if all_lons and all_lats:
m.fit_bounds(
[[min(all_lats), min(all_lons)], [max(all_lats), max(all_lons)]]
)
# Display the map
with st.chat_message("tool"):
st.markdown("**Map View**")
map_html = m._repr_html_()
components.html(map_html, height=400)
# Main UI
st.title("Geo Assistant")
# Display chat history
for item in st.session_state.chat_history:
role = item["role"]
content = item["content"]
with st.chat_message(role):
if role == "assistant":
# For assistant messages, check if it's a tool message
if item.get("is_tool"):
st.code(content, language="json")
else:
st.markdown(content)
else:
st.markdown(content)
# Chat input
if prompt := st.chat_input("Type your message..."):
st.session_state.chat_history.append({"role": "user", "content": prompt})
for msg_type, content in stream_chat(prompt):
if msg_type == "tool":
st.session_state.chat_history.append({"role": "tool", "content": content})
elif msg_type in ["ai", "assistant"]:
st.session_state.chat_history.append(
{"role": "assistant", "content": content}
)
st.rerun()
with st.chat_message("user"):
st.markdown(prompt)
stream_chat(prompt)
+2 -3
View File
@@ -6,7 +6,6 @@ from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from typing import Annotated
from geo_assistant.agent.state import GeoAssistantState
from geojson_pydantic import Feature
@@ -18,7 +17,7 @@ async def get_search_area(
) -> Command:
"""Get a search area buffer in km around the place defined in the agent state."""
place_feature = state["place"]
place_feature = state.get("place")
if not place_feature:
return Command(
@@ -52,7 +51,7 @@ async def get_search_area(
buffer_feature = Feature(
type="Feature",
geometry=gdf.iloc[0].geometry.__geo_interface__,
properties={},
properties=place_feature.properties.copy(),
)
return Command(
+12 -16
View File
@@ -8,6 +8,7 @@ from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.types import Command
from geojson_pydantic import Feature
# Load environment variables
load_dotenv()
@@ -71,27 +72,22 @@ async def get_place(
geometry = json.loads(location_results[0][-1])
# Create FeatureCollection
feature_collection = {
"type": "FeatureCollection",
"features": [
{
"type": "Feature",
"geometry": geometry,
"properties": {
"name": location_results[0][2],
"overture_id": location_results[0][0],
},
}
],
}
feature = Feature(
type="Feature",
geometry=geometry,
properties={
"overture_id": location_results[0][0],
"name": location_results[0][2],
"socials": location_results[0][4],
},
)
return Command(
update={
"place": feature_collection,
"place": feature,
"messages": [
ToolMessage(
content=f"Found place with Overture name: {location_results[0][2]} based on user query",
content=f"Found place with Overture name: {location_results[0][2]} based on user query. Socials: {location_results[0][4]}",
tool_call_id=tool_call_id,
)
],
+1 -1
View File
@@ -66,7 +66,7 @@ _SUMMARIZER_AGENT = SatImgSummaryAgent()
@tool
def summarize_sat_img(
async def summarize_sat_img(
img_url: str,
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
) -> Command:
+6 -1
View File
@@ -29,7 +29,12 @@ async def test_hello_world(initialized_app):
"/chat",
json={
"agent_state_input": {
"messages": [{"content": "Hello, world!", "type": "human"}],
"messages": [
{
"content": "Find the Neighbourhood Cafe in Lisbon and buffer 0.5km around it",
"type": "human",
}
],
"place": None,
"search_area": None,
},
Generated
+39
View File
@@ -325,6 +325,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" },
]
[[package]]
name = "branca"
version = "0.8.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2" },
]
sdist = { url = "https://files.pythonhosted.org/packages/32/14/9d409124bda3f4ab7af3802aba07181d1fd56aa96cc4b999faea6a27a0d2/branca-0.8.2.tar.gz", hash = "sha256:e5040f4c286e973658c27de9225c1a5a7356dd0702a7c8d84c0f0dfbde388fe7", size = 27890, upload-time = "2025-10-06T10:28:20.305Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/50/fc9680058e63161f2f63165b84c957a0df1415431104c408e8104a3a18ef/branca-0.8.2-py3-none-any.whl", hash = "sha256:2ebaef3983e3312733c1ae2b793b0a8ba3e1c4edeb7598e10328505280cf2f7c", size = 26193, upload-time = "2025-10-06T10:28:19.255Z" },
]
[[package]]
name = "cachetools"
version = "6.2.2"
@@ -788,6 +800,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" },
]
[[package]]
name = "folium"
version = "0.20.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "branca" },
{ name = "jinja2" },
{ name = "numpy" },
{ name = "requests" },
{ name = "xyzservices" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c7/76/84a1b1b00ce71f9c0c44af7d80f310c02e2e583591fe7d4cb03baecd0d3f/folium-0.20.0.tar.gz", hash = "sha256:a0d78b9d5a36ba7589ca9aedbd433e84e9fcab79cd6ac213adbcff922e454cb9", size = 109932, upload-time = "2025-06-16T20:22:51.803Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/a8/5f764f333204db0390362a4356d03a43626997f26818a0e9396f1b3bd8c9/folium-0.20.0-py2.py3-none-any.whl", hash = "sha256:f0bc2a92acde20bca56367aa5c1c376c433f450608d058daebab2fc9bf8198bf", size = 113394, upload-time = "2025-06-16T20:22:50.318Z" },
]
[[package]]
name = "fonttools"
version = "4.61.0"
@@ -920,6 +948,7 @@ dependencies = [
{ name = "dspy" },
{ name = "duckdb" },
{ name = "fastapi" },
{ name = "folium" },
{ name = "geojson-pydantic" },
{ name = "geopandas" },
{ name = "httpx" },
@@ -953,6 +982,7 @@ requires-dist = [
{ name = "dspy", specifier = ">=3.0.4" },
{ name = "duckdb" },
{ name = "fastapi" },
{ name = "folium", specifier = ">=0.15.0" },
{ name = "geojson-pydantic" },
{ name = "geopandas", specifier = ">=1.1.1" },
{ name = "httpx" },
@@ -4141,6 +4171,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0f/c9/7243eb3f9eaabd1a88a5a5acadf06df2d83b100c62684b7425c6a11bcaa8/xxhash-3.6.0-cp314-cp314t-win_arm64.whl", hash = "sha256:bb79b1e63f6fd84ec778a4b1916dfe0a7c3fdb986c06addd5db3a0d413819d95", size = 28898, upload-time = "2025-10-02T14:36:17.843Z" },
]
[[package]]
name = "xyzservices"
version = "2025.11.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ee/0f/022795fc1201e7c29e742a509913badb53ce0b38f64b6db859e2f6339da9/xyzservices-2025.11.0.tar.gz", hash = "sha256:2fc72b49502b25023fd71e8f532fb4beddbbf0aa124d90ea25dba44f545e17ce", size = 1135703, upload-time = "2025-11-22T11:31:51.82Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ef/5c/2c189d18d495dd0fa3f27ccc60762bbc787eed95b9b0147266e72bb76585/xyzservices-2025.11.0-py3-none-any.whl", hash = "sha256:de66a7599a8d6dad63980b77defd1d8f5a5a9cb5fc8774ea1c6e89ca7c2a3d2f", size = 93916, upload-time = "2025-11-22T11:31:50.525Z" },
]
[[package]]
name = "yarl"
version = "1.22.0"