Initial commit

This commit is contained in:
Daniel Wiesmann
2025-12-04 09:36:14 +00:00
parent 9e541ac6b7
commit 1a11473421
14 changed files with 2489 additions and 0 deletions
+19
View File
@@ -0,0 +1,19 @@
name: Ruff
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v3
with:
args: check
- uses: astral-sh/ruff-action@v3
with:
args: format --check
+3
View File
@@ -205,3 +205,6 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/
# VSCode
.vscode/
+9
View File
@@ -0,0 +1,9 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
hooks:
# Run the linter
- id: ruff
args: [--fix]
# Run the formatter
- id: ruff-format
+58
View File
@@ -0,0 +1,58 @@
# Geo Assistant
A geographic assistant that helps answer questions and perform tasks related to locations and geographic data.
## Environment Setup
The project uses environment variables for configuration. Copy `.env.example` to `.env` and customize as needed:
```bash
cp .env.example .env
```
Edit `.env` to set your configuration:
- `OLLAMA_MODEL`: Model name (default: `llama3.2`)
- `OLLAMA_BASE_URL`: Ollama server URL (default: `http://localhost:11434`)
- `API_BASE_URL`: API base URL for the frontend (default: `http://localhost:8000`)
The application will automatically load these variables from the `.env` file.
## Development Setup
### Pre-commit Hooks
This project uses pre-commit hooks to ensure code quality. To set up pre-commit:
1. Install dependencies (including pre-commit):
```bash
uv sync
```
2. Install the git hooks:
```bash
uv run pre-commit install
```
Pre-commit will now automatically run ruff linting and formatting checks before each commit.
To manually run pre-commit on all files:
```bash
uv run pre-commit run --all-files
```
## Running the API
```bash
uvicorn geo_assistant.api.app:app --reload
```
The API will be available at `http://localhost:8000`.
## Running the Frontend
```bash
streamlit run src/geo_assistant/frontend/app.py
```
The frontend will be available at `http://localhost:8501`.
+35
View File
@@ -0,0 +1,35 @@
[project]
name = "geo-assistant"
version = "0.0.1"
description = "Geo Assistant"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"langgraph",
"fastapi",
"pydantic",
"geojson_pydantic",
"streamlit",
"httpx",
"uvicorn[standard]",
"langchain-ollama",
"langchain",
"python-dotenv",
"duckdb",
"shapely",
]
[dependency-groups]
dev = [
"ruff",
"pytest",
"pytest-asyncio",
"pre-commit",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/geo_assistant"]
+29
View File
@@ -0,0 +1,29 @@
import datetime
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent
from geo_assistant.agent.state import AgentState
from geo_assistant.agent.llms import llm
from geo_assistant.tools.overture import get_overture_locations
SYSTEM_PROMPT = """
You are a helpful assistant that can answer questions and help with tasks.
You have location and division tools available to you. Only use this data if the user asks for it.
The current date and time is {now}.
"""
async def create_graph():
checkpointer = InMemorySaver()
graph = create_agent(
model=llm,
tools=[get_overture_locations], # [get_overture_locations, geocode_division],
system_prompt=SYSTEM_PROMPT.format(
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
),
state_schema=AgentState,
checkpointer=checkpointer,
)
return graph
+16
View File
@@ -0,0 +1,16 @@
import os
from dotenv import load_dotenv
from langchain_ollama import ChatOllama
# Load environment variables from .env file
load_dotenv()
# Get model name from environment variable, default to llama3.2
MODEL_NAME = os.environ.get("OLLAMA_MODEL", "llama3.2")
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
llm = ChatOllama(
model=MODEL_NAME,
base_url=OLLAMA_BASE_URL,
)
+10
View File
@@ -0,0 +1,10 @@
from langchain.agents import AgentState as BaseAgentState
from geojson_pydantic import FeatureCollection
from typing import Optional
from pydantic import Field
class AgentState(BaseAgentState):
feature_collection: Optional[FeatureCollection] = Field(
default=None, description="FeatureCollection to be used for the analysis"
)
+125
View File
@@ -0,0 +1,125 @@
import json
from contextlib import aclosing, asynccontextmanager
from typing import Any, AsyncGenerator, Dict
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import logging
from pydantic import UUID4
from geo_assistant.agent.graph import create_graph
from geo_assistant.agent.state import AgentState
from geo_assistant.api.schemas.chat import ChatRequestBody, ChatResponse
logger = logging.getLogger(__name__)
# Whitelist state fields that can be set by the user.
# Note that these attrs need to be pydantic Fields and
# need a description in the AgentState model.
UI_SET_FIELDS_WHITELIST = ["feature_collection", "messages"]
@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.chatbot = await create_graph()
yield
app = FastAPI(title="Geo Assistant", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all HTTP methods (GET, POST, PUT, DELETE, etc.)
allow_headers=["*"], # Allows all headers
)
async def stream_chat(
ui_state_update: AgentState,
thread_id: UUID4,
chatbot: Any,
request: Request,
) -> AsyncGenerator[bytes, None]:
config: Dict[str, Any] = {
"configurable": {
"thread_id": str(thread_id),
}
}
state_updates = {}
vars_to_update = {
key: val
for key, val in ui_state_update.items()
if val and key in UI_SET_FIELDS_WHITELIST
}
logger.debug(f"State variables to update: {vars_to_update}")
ui_messages = []
for key in vars_to_update.keys():
if hasattr(AgentState, key):
field_info = getattr(AgentState, key)
description = field_info.description if field_info else f"Field {key}"
if description:
ui_messages.append(
{
"content": f"Manually selected data for field {key}: {description}",
"type": "human",
}
)
# Add UI messages to the existing messages if they exist
existing_messages = vars_to_update.get("messages", [])
vars_to_update["messages"] = ui_messages + existing_messages
state_updates.update(vars_to_update)
stream = chatbot.astream(
input=state_updates,
config=config,
stream_mode="updates",
)
try:
async with aclosing(stream):
async for update in stream:
if await request.is_disconnected():
logger.info("Client disconnected; stopping stream.")
break
agent = next(iter(update.keys()))
payload = update[agent]
if "feature_collection" not in payload: # TODO
payload["feature_collection"] = None
state_payload = AgentState(**payload)
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
line = json.dumps(resp.model_dump()) + "\n"
yield line.encode("utf-8")
except Exception as e:
logger.warning("stream_chat error: %r", e)
@app.post("/chat")
async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse:
generator = stream_chat(
ui_state_update=request.agent_state_input,
thread_id=request.thread_id,
chatbot=http_request.app.state.chatbot,
request=http_request,
)
return StreamingResponse(
generator,
media_type="application/x-ndjson; charset=utf-8",
headers={
"Cache-Control": "no-cache",
# If you run behind nginx, this prevents buffering of the stream:
"X-Accel-Buffering": "no",
},
)
+12
View File
@@ -0,0 +1,12 @@
from pydantic import BaseModel
from geo_assistant.agent.state import AgentState
class ChatRequestBody(BaseModel):
agent_state_input: AgentState
thread_id: str
class ChatResponse(BaseModel):
thread_id: str
state: AgentState
+141
View File
@@ -0,0 +1,141 @@
import json
import os
import uuid
import httpx
import streamlit as st
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# API configuration
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")
st.set_page_config(page_title="Geo Assistant", page_icon="💬")
# Initialize session state
if "thread_id" not in st.session_state:
st.session_state.thread_id = str(uuid.uuid4())
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
def send_message(user_message: str, message_container):
"""Send a message to the API and stream the response."""
thread_id = st.session_state.thread_id
# Prepare request body
request_body = {
"thread_id": thread_id,
"agent_state": {
"messages": [{"type": "human", "content": user_message}],
"features": [],
},
}
# Create a placeholder for streaming response
response_placeholder = message_container.empty()
last_messages = []
try:
with httpx.stream(
"POST",
f"{API_BASE_URL}/chat",
json=request_body,
timeout=60.0,
) as response:
response.raise_for_status()
for line in response.iter_lines():
if line:
try:
data = json.loads(line)
state = data.get("state", {})
messages = state.get("messages", [])
# Display the latest messages
if messages:
display_parts = []
for msg in messages:
msg_type = msg.get("type", "")
content = msg.get("content", "")
if msg_type == "tool":
# Display tool messages as JSON code blocks
if isinstance(content, (dict, list)):
content_str = json.dumps(content, indent=2)
else:
content_str = str(content)
display_parts.append(
f"**Tool:**\n```json\n{content_str}\n```"
)
elif msg_type in ["ai", "assistant"]:
# Display AI messages as normal text
display_parts.append(f"**AI:** {content}")
if display_parts:
full_response = "\n\n".join(display_parts)
response_placeholder.markdown(full_response)
last_messages = messages
except json.JSONDecodeError:
continue
# Return the final messages for history
return last_messages
except httpx.HTTPError as e:
message_container.error(f"Error connecting to API: {e}")
return []
except Exception as e:
message_container.error(f"Error: {e}")
return []
# Main UI
st.title("Geo Assistant")
# Display chat history
for item in st.session_state.chat_history:
role = item["role"]
content = item["content"]
with st.chat_message(role):
if role == "assistant":
# For assistant messages, check if it's a tool message
if item.get("is_tool"):
st.code(content, language="json")
else:
st.markdown(content)
else:
st.markdown(content)
# Chat input
if prompt := st.chat_input("Type your message..."):
# Add user message to history
st.session_state.chat_history.append({"role": "user", "content": prompt})
# Send message and get response
with st.chat_message("assistant"):
final_messages = send_message(prompt, st.container())
# Add response to history
if final_messages:
for msg in final_messages:
msg_type = msg.get("type", "")
content = msg.get("content", "")
if msg_type == "tool":
if isinstance(content, (dict, list)):
content_str = json.dumps(content, indent=2)
else:
content_str = str(content)
st.session_state.chat_history.append(
{"role": "assistant", "content": content_str, "is_tool": True}
)
elif msg_type in ["ai", "assistant"]:
st.session_state.chat_history.append(
{"role": "assistant", "content": content, "is_tool": False}
)
st.rerun()
+248
View File
@@ -0,0 +1,248 @@
from typing import Optional, Annotated
import duckdb
from geojson_pydantic import Feature
from shapely import wkt
from shapely.geometry import mapping
from langchain_core.tools import tool
from langgraph.types import Command
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
@tool
def get_overture_locations(
area_of_interest: Feature,
place_name: Optional[str] = None,
place_type: Optional[str] = None,
overture_release: str = "2024-11-13.0",
similarity_threshold: float = 0.6,
tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command:
"""
Get locations from Overture Maps.
Parameters
----------
area_of_interest : Feature
Area of interest to search for locations in
place_name : str, optional
Name of the place to search for
place_type : str, optional
Type of the place to search for
overture_release : str
Overture Maps release version
similarity_threshold : float
Minimum similarity score (0-1) for fuzzy name matching
tool_call_id : str
Tool call ID
Returns
-------
Command
Command that updates state with location features
"""
con = duckdb.connect()
con.execute("INSTALL spatial;")
con.execute("LOAD spatial;")
con.execute("INSTALL httpfs;")
con.execute("LOAD httpfs;")
con.execute(
"""
CREATE OR REPLACE TABLE aoi AS
SELECT ST_GeomFromGeoJSON(?) AS geom
""",
[area_of_interest.geometry.model_dump_json()],
)
base_url = f"s3://overturemaps-us-west-2/release/{overture_release}/theme=places/type=place/*"
where_conditions = ["ST_Within(ST_GeomFromWKB(geometry), (SELECT geom FROM aoi))"]
if place_type:
where_conditions.append(f"categories.primary = '{place_type}'")
if place_name:
where_conditions.append(
f"jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) >= {similarity_threshold}"
)
where_clause = " AND ".join(where_conditions)
query = f"""
SELECT
id,
ST_AsText(ST_GeomFromWKB(geometry)) as geometry_wkt,
names.primary as name,
categories.primary as primary_category,
confidence,
websites,
phones,
addresses
FROM read_parquet('{base_url}', filename=true, hive_partitioning=1)
WHERE {where_clause}
"""
result = con.execute(query).fetchall()
columns = [desc[0] for desc in con.description]
locations = [dict(zip(columns, row)) for row in result]
# Convert locations to GeoJSON Features
features = []
for loc in locations:
# Parse WKT geometry to GeoJSON
geom_wkt = loc.get("geometry_wkt")
if geom_wkt:
shapely_geom = wkt.loads(geom_wkt)
geom_dict = mapping(shapely_geom)
# Create properties from location data
properties = {
"id": loc.get("id"),
"name": loc.get("name"),
"primary_category": loc.get("primary_category"),
"confidence": loc.get("confidence"),
"websites": loc.get("websites"),
"phones": loc.get("phones"),
"addresses": loc.get("addresses"),
}
feature = Feature(geometry=geom_dict, properties=properties)
features.append(feature)
con.close()
tool_message = f"Found {len(features)} locations matching the criteria"
return Command(
update={
"features": features,
"messages": [ToolMessage(content=tool_message, tool_call_id=tool_call_id)],
},
)
@tool
def geocode_division(
query: str,
level: Optional[str] = None,
overture_release: str = "2024-11-13.0",
similarity_threshold: float = 0.6,
limit: int = 10,
tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command:
"""
Geocode a place name using Overture divisions data.
Parameters
----------
query : str
Place name to search for (e.g., "San Francisco", "California", "United States")
level : str, optional
Division level to filter by. Options:
- 'country'
- 'region' (states, provinces)
- 'county' (counties, districts)
- 'locality' (cities, towns)
- 'localadmin' (local administrative areas)
- 'neighborhood'
overture_release : str
Overture Maps release version
similarity_threshold : float
Minimum similarity score (0-1) for fuzzy name matching
limit : int
Maximum number of results to return
Returns
-------
Command
Command that updates state with division features
"""
con = duckdb.connect()
con.execute("INSTALL spatial;")
con.execute("LOAD spatial;")
con.execute("INSTALL httpfs;")
con.execute("LOAD httpfs;")
base_url = f"s3://overturemaps-us-west-2/release/{overture_release}/theme=divisions/type=division/*"
where_conditions = [
f"jaro_winkler_similarity(LOWER(names.primary), LOWER('{query}')) >= {similarity_threshold}"
]
if level:
where_conditions.append(f"subtype = '{level}'")
where_clause = " AND ".join(where_conditions)
query_sql = f"""
SELECT
id,
ST_AsText(ST_GeomFromWKB(geometry)) as geometry_wkt,
names.primary as name,
names.common as common_names,
subtype as division_level,
country,
region,
hierarchies,
population,
capital,
wikidata,
sources,
jaro_winkler_similarity(LOWER(names.primary), LOWER('{query}')) as similarity_score
FROM read_parquet('{base_url}', filename=true, hive_partitioning=1)
WHERE {where_clause}
ORDER BY similarity_score DESC
LIMIT {limit}
"""
result = con.execute(query_sql).fetchall()
columns = [desc[0] for desc in con.description]
divisions = [dict(zip(columns, row)) for row in result]
# Convert divisions to GeoJSON Features
features = []
for div in divisions:
# Parse WKT geometry to GeoJSON
geom_wkt = div.get("geometry_wkt")
if geom_wkt:
shapely_geom = wkt.loads(geom_wkt)
geom_dict = mapping(shapely_geom)
# Create properties from division data
properties = {
"id": div.get("id"),
"name": div.get("name"),
"common_names": div.get("common_names"),
"division_level": div.get("division_level"),
"country": div.get("country"),
"region": div.get("region"),
"hierarchies": div.get("hierarchies"),
"population": div.get("population"),
"capital": div.get("capital"),
"wikidata": div.get("wikidata"),
"sources": div.get("sources"),
"similarity_score": div.get("similarity_score"),
}
feature = Feature(geometry=geom_dict, properties=properties)
features.append(feature)
con.close()
tool_message = f"Found {len(features)} divisions matching '{query}'"
return Command(
update={
"features": features,
"messages": [ToolMessage(content=tool_message, tool_call_id=tool_call_id)],
},
)
+45
View File
@@ -0,0 +1,45 @@
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from uuid import uuid4
from geo_assistant.api.app import app
from geo_assistant.agent.graph import create_graph
@pytest_asyncio.fixture
async def initialized_app():
"""Initialize the app's chatbot before testing"""
# Manually initialize the chatbot as the lifespan would
app.state.chatbot = await create_graph()
yield app
# Cleanup if needed
if hasattr(app.state, "chatbot"):
del app.state.chatbot
@pytest.mark.asyncio
async def test_hello_world(initialized_app):
"""Hello world test for the API"""
async with AsyncClient(
transport=ASGITransport(app=initialized_app), base_url="http://test"
) as client:
thread_id = uuid4()
response = await client.post(
"/chat",
json={
"agent_state_input": {
"messages": [{"content": "Hello, world!", "type": "human"}],
"feature_collection": None,
},
"thread_id": str(thread_id),
},
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/x-ndjson; charset=utf-8"
# Read the streaming response
content = response.text
assert content is not None
assert len(content) > 0
Generated
+1739
View File
File diff suppressed because it is too large Load Diff