mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Initial commit
This commit is contained in:
@@ -0,0 +1,19 @@
|
|||||||
|
name: Ruff
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
ruff:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: astral-sh/ruff-action@v3
|
||||||
|
with:
|
||||||
|
args: check
|
||||||
|
- uses: astral-sh/ruff-action@v3
|
||||||
|
with:
|
||||||
|
args: format --check
|
||||||
@@ -205,3 +205,6 @@ cython_debug/
|
|||||||
marimo/_static/
|
marimo/_static/
|
||||||
marimo/_lsp/
|
marimo/_lsp/
|
||||||
__marimo__/
|
__marimo__/
|
||||||
|
|
||||||
|
# VSCode
|
||||||
|
.vscode/
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.6.9
|
||||||
|
hooks:
|
||||||
|
# Run the linter
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix]
|
||||||
|
# Run the formatter
|
||||||
|
- id: ruff-format
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
# Geo Assistant
|
||||||
|
|
||||||
|
A geographic assistant that helps answer questions and perform tasks related to locations and geographic data.
|
||||||
|
|
||||||
|
## Environment Setup
|
||||||
|
|
||||||
|
The project uses environment variables for configuration. Copy `.env.example` to `.env` and customize as needed:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
```
|
||||||
|
|
||||||
|
Edit `.env` to set your configuration:
|
||||||
|
|
||||||
|
- `OLLAMA_MODEL`: Model name (default: `llama3.2`)
|
||||||
|
- `OLLAMA_BASE_URL`: Ollama server URL (default: `http://localhost:11434`)
|
||||||
|
- `API_BASE_URL`: API base URL for the frontend (default: `http://localhost:8000`)
|
||||||
|
|
||||||
|
The application will automatically load these variables from the `.env` file.
|
||||||
|
|
||||||
|
## Development Setup
|
||||||
|
|
||||||
|
### Pre-commit Hooks
|
||||||
|
|
||||||
|
This project uses pre-commit hooks to ensure code quality. To set up pre-commit:
|
||||||
|
|
||||||
|
1. Install dependencies (including pre-commit):
|
||||||
|
```bash
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install the git hooks:
|
||||||
|
```bash
|
||||||
|
uv run pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
Pre-commit will now automatically run ruff linting and formatting checks before each commit.
|
||||||
|
|
||||||
|
To manually run pre-commit on all files:
|
||||||
|
```bash
|
||||||
|
uv run pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running the API
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn geo_assistant.api.app:app --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
The API will be available at `http://localhost:8000`.
|
||||||
|
|
||||||
|
## Running the Frontend
|
||||||
|
|
||||||
|
```bash
|
||||||
|
streamlit run src/geo_assistant/frontend/app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
The frontend will be available at `http://localhost:8501`.
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
[project]
|
||||||
|
name = "geo-assistant"
|
||||||
|
version = "0.0.1"
|
||||||
|
description = "Geo Assistant"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.13"
|
||||||
|
dependencies = [
|
||||||
|
"langgraph",
|
||||||
|
"fastapi",
|
||||||
|
"pydantic",
|
||||||
|
"geojson_pydantic",
|
||||||
|
"streamlit",
|
||||||
|
"httpx",
|
||||||
|
"uvicorn[standard]",
|
||||||
|
"langchain-ollama",
|
||||||
|
"langchain",
|
||||||
|
"python-dotenv",
|
||||||
|
"duckdb",
|
||||||
|
"shapely",
|
||||||
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"ruff",
|
||||||
|
"pytest",
|
||||||
|
"pytest-asyncio",
|
||||||
|
"pre-commit",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/geo_assistant"]
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
import datetime
|
||||||
|
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from geo_assistant.agent.state import AgentState
|
||||||
|
from geo_assistant.agent.llms import llm
|
||||||
|
from geo_assistant.tools.overture import get_overture_locations
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = """
|
||||||
|
You are a helpful assistant that can answer questions and help with tasks.
|
||||||
|
|
||||||
|
You have location and division tools available to you. Only use this data if the user asks for it.
|
||||||
|
|
||||||
|
The current date and time is {now}.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def create_graph():
    """Build and return the agent graph.

    The system prompt is rendered once, when the graph is built, with the
    current local timestamp. Conversation state is checkpointed in memory.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    agent_options = {
        "model": llm,
        # geocode_division is currently left out of the tool list.
        "tools": [get_overture_locations],
        "system_prompt": SYSTEM_PROMPT.format(now=timestamp),
        "state_schema": AgentState,
        "checkpointer": InMemorySaver(),
    }
    return create_agent(**agent_options)
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from langchain_ollama import ChatOllama
|
||||||
|
|
||||||
|
# Load environment variables from .env file
load_dotenv()

# Model name and Ollama server URL are read from the environment, falling
# back to llama3.2 and http://localhost:11434 when the variables are unset.
MODEL_NAME = os.environ.get("OLLAMA_MODEL", "llama3.2")
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")

# Chat model instance, created once when this module is imported.
llm = ChatOllama(
    model=MODEL_NAME,
    base_url=OLLAMA_BASE_URL,
)
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
from langchain.agents import AgentState as BaseAgentState
|
||||||
|
from geojson_pydantic import FeatureCollection
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
|
||||||
|
class AgentState(BaseAgentState):
    """Agent state extended with an optional GeoJSON feature collection."""

    # Optional FeatureCollection carried next to the base agent state; it
    # defaults to None. The Field description is read by the API layer when
    # it announces user-selected data to the agent.
    feature_collection: Optional[FeatureCollection] = Field(
        default=None, description="FeatureCollection to be used for the analysis"
    )
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
import json
|
||||||
|
from contextlib import aclosing, asynccontextmanager
|
||||||
|
from typing import Any, AsyncGenerator, Dict
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import logging
|
||||||
|
from pydantic import UUID4
|
||||||
|
|
||||||
|
from geo_assistant.agent.graph import create_graph
|
||||||
|
from geo_assistant.agent.state import AgentState
|
||||||
|
from geo_assistant.api.schemas.chat import ChatRequestBody, ChatResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Whitelist state fields that can be set by the user.
|
||||||
|
# Note that these attrs need to be pydantic Fields and
|
||||||
|
# need a description in the AgentState model.
|
||||||
|
UI_SET_FIELDS_WHITELIST = ["feature_collection", "messages"]
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the agent graph on startup and expose it as ``app.state.chatbot``."""
    chatbot = await create_graph()
    app.state.chatbot = chatbot
    yield
|
||||||
|
|
||||||
|
|
||||||
|
# Application instance; the lifespan hook builds the agent graph on startup.
app = FastAPI(title="Geo Assistant", lifespan=lifespan)

# NOTE(review): this CORS policy accepts any origin, method and header while
# also allowing credentials. That is very permissive -- confirm it is intended
# outside local development, or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all HTTP methods (GET, POST, PUT, DELETE, etc.)
    allow_headers=["*"],  # Allows all headers
)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_chat(
    ui_state_update: AgentState,
    thread_id: UUID4,
    chatbot: Any,
    request: Request,
) -> AsyncGenerator[bytes, None]:
    """Stream the agent's updates for one chat turn as NDJSON-encoded bytes.

    Parameters
    ----------
    ui_state_update : AgentState
        State sent by the UI. Only truthy values whose key is listed in
        ``UI_SET_FIELDS_WHITELIST`` are forwarded to the agent.
    thread_id : UUID4
        Conversation identifier. It is stringified and used both as the
        ``thread_id`` of the run configuration and in every response.
    chatbot : Any
        Object exposing ``astream(input=..., config=..., stream_mode=...)``.
    request : Request
        Incoming HTTP request, polled to detect a client disconnect.

    Yields
    ------
    bytes
        One UTF-8 encoded, newline-terminated JSON document per agent update.

    Notes
    -----
    Exceptions raised while streaming are logged (with traceback) and end the
    stream; they are not propagated to the caller.
    """
    thread = str(thread_id)
    config: Dict[str, Any] = {"configurable": {"thread_id": thread}}

    state_updates = {
        key: val
        for key, val in ui_state_update.items()
        if val and key in UI_SET_FIELDS_WHITELIST
    }
    # Lazy %-formatting: the (possibly large) state is only rendered when
    # debug logging is actually enabled.
    logger.debug("State variables to update: %s", state_updates)

    # Announce each UI-selected field that exposes a description on AgentState.
    ui_messages = []
    for key in state_updates:
        if not hasattr(AgentState, key):
            continue
        field_info = getattr(AgentState, key)
        description = field_info.description if field_info else f"Field {key}"
        if description:
            ui_messages.append(
                {
                    "content": f"Manually selected data for field {key}: {description}",
                    "type": "human",
                }
            )

    # Add UI messages in front of the messages sent by the user, if any.
    state_updates["messages"] = ui_messages + state_updates.get("messages", [])

    stream = chatbot.astream(
        input=state_updates,
        config=config,
        stream_mode="updates",
    )

    try:
        async with aclosing(stream):
            async for update in stream:
                if await request.is_disconnected():
                    logger.info("Client disconnected; stopping stream.")
                    break

                agent = next(iter(update))
                # Default feature_collection to None without mutating the
                # update object handed out by the graph.  # TODO
                payload = {"feature_collection": None, **update[agent]}
                state_payload = AgentState(**payload)

                resp = ChatResponse(thread_id=thread, state=state_payload)

                line = json.dumps(resp.model_dump()) + "\n"
                yield line.encode("utf-8")

    except Exception as e:
        # Best effort: end the stream, but keep the traceback for diagnosis.
        logger.warning("stream_chat error: %r", e, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/chat")
async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse:
    """Run one chat turn and stream the agent's updates back as NDJSON."""
    response_headers = {
        "Cache-Control": "no-cache",
        # If you run behind nginx, this prevents buffering of the stream:
        "X-Accel-Buffering": "no",
    }
    updates = stream_chat(
        ui_state_update=request.agent_state_input,
        thread_id=request.thread_id,
        chatbot=http_request.app.state.chatbot,
        request=http_request,
    )
    return StreamingResponse(
        updates,
        media_type="application/x-ndjson; charset=utf-8",
        headers=response_headers,
    )
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from geo_assistant.agent.state import AgentState
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequestBody(BaseModel):
    """Request body accepted by the chat endpoint."""

    # State supplied by the client for this turn.
    agent_state_input: AgentState
    # Identifier of the conversation thread, as a string.
    thread_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
    """One streamed response item: the thread it belongs to and a state payload."""

    # Identifier of the conversation thread, as a string.
    thread_id: str
    # Agent state carried by this response item.
    state: AgentState
|
||||||
@@ -0,0 +1,141 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import streamlit as st
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables from .env file
load_dotenv()

# API configuration: base URL of the backend, overridable via API_BASE_URL
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")

st.set_page_config(page_title="Geo Assistant", page_icon="💬")

# Initialize session state: one random thread id and an empty chat history
# are created the first time the keys are missing from the session.
if "thread_id" not in st.session_state:
    st.session_state.thread_id = str(uuid.uuid4())
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
|
||||||
|
|
||||||
|
|
||||||
|
def send_message(user_message: str, message_container):
    """Send a message to the API and stream the response.

    Parameters
    ----------
    user_message : str
        Text typed by the user; sent as a single human message.
    message_container
        Streamlit container used for the live response placeholder and for
        error reporting.

    Returns
    -------
    list
        The messages of the last streamed update that produced displayable
        output, or an empty list when nothing was displayed or the request
        failed (the error is shown in ``message_container``).
    """
    thread_id = st.session_state.thread_id

    # Prepare request body. The /chat endpoint validates it against
    # ChatRequestBody, which expects the state under ``agent_state_input``,
    # and AgentState names its geometry field ``feature_collection``.
    request_body = {
        "thread_id": thread_id,
        "agent_state_input": {
            "messages": [{"type": "human", "content": user_message}],
            "feature_collection": None,
        },
    }

    # Create a placeholder for streaming response
    response_placeholder = message_container.empty()
    last_messages = []

    try:
        with httpx.stream(
            "POST",
            f"{API_BASE_URL}/chat",
            json=request_body,
            timeout=60.0,
        ) as response:
            response.raise_for_status()

            for line in response.iter_lines():
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Ignore lines that are not valid JSON documents
                    continue

                state = data.get("state", {})
                messages = state.get("messages", [])
                if not messages:
                    continue

                # Display the latest messages
                display_parts = []
                for msg in messages:
                    msg_type = msg.get("type", "")
                    content = msg.get("content", "")

                    if msg_type == "tool":
                        # Display tool messages as JSON code blocks
                        if isinstance(content, (dict, list)):
                            content_str = json.dumps(content, indent=2)
                        else:
                            content_str = str(content)
                        display_parts.append(
                            f"**Tool:**\n```json\n{content_str}\n```"
                        )
                    elif msg_type in ["ai", "assistant"]:
                        # Display AI messages as normal text
                        display_parts.append(f"**AI:** {content}")

                if display_parts:
                    full_response = "\n\n".join(display_parts)
                    response_placeholder.markdown(full_response)
                    last_messages = messages

        # Return the final messages for history
        return last_messages

    except httpx.HTTPError as e:
        message_container.error(f"Error connecting to API: {e}")
        return []
    except Exception as e:
        message_container.error(f"Error: {e}")
        return []
|
||||||
|
|
||||||
|
|
||||||
|
# Main UI
st.title("Geo Assistant")

# Replay the stored conversation
for item in st.session_state.chat_history:
    role = item["role"]
    content = item["content"]

    with st.chat_message(role):
        # Only assistant entries flagged as tool output are rendered as JSON
        # code; everything else is rendered as markdown.
        if role == "assistant" and item.get("is_tool"):
            st.code(content, language="json")
        else:
            st.markdown(content)

# Chat input
prompt = st.chat_input("Type your message...")
if prompt:
    # Add user message to history
    st.session_state.chat_history.append({"role": "user", "content": prompt})

    # Send message and get response
    with st.chat_message("assistant"):
        final_messages = send_message(prompt, st.container())

    # Add response to history
    for msg in final_messages or ():
        msg_type = msg.get("type", "")
        content = msg.get("content", "")

        if msg_type == "tool":
            content_str = (
                json.dumps(content, indent=2)
                if isinstance(content, (dict, list))
                else str(content)
            )
            history_entry = {
                "role": "assistant",
                "content": content_str,
                "is_tool": True,
            }
        elif msg_type in ["ai", "assistant"]:
            history_entry = {"role": "assistant", "content": content, "is_tool": False}
        else:
            # Other message types are not kept in the history
            continue
        st.session_state.chat_history.append(history_entry)

    st.rerun()
|
||||||
@@ -0,0 +1,248 @@
|
|||||||
|
from typing import Optional, Annotated
|
||||||
|
import duckdb
|
||||||
|
from geojson_pydantic import Feature
|
||||||
|
from shapely import wkt
|
||||||
|
from shapely.geometry import mapping
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langgraph.types import Command
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.tools.base import InjectedToolCallId
|
||||||
|
|
||||||
|
|
||||||
|
@tool
def get_overture_locations(
    area_of_interest: Feature,
    place_name: Optional[str] = None,
    place_type: Optional[str] = None,
    overture_release: str = "2024-11-13.0",
    similarity_threshold: float = 0.6,
    tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command:
    """
    Get locations from Overture Maps.

    Parameters
    ----------
    area_of_interest : Feature
        Area of interest to search for locations in
    place_name : str, optional
        Name of the place to search for
    place_type : str, optional
        Type of the place to search for
    overture_release : str
        Overture Maps release version
    similarity_threshold : float
        Minimum similarity score (0-1) for fuzzy name matching
    tool_call_id : str
        Tool call ID

    Returns
    -------
    Command
        Command that updates state with location features
    """
    # The docstring above is left untouched on purpose: this function is
    # wrapped by @tool, so implementation notes are kept in comments instead.

    con = duckdb.connect()
    try:
        con.execute("INSTALL spatial;")
        con.execute("LOAD spatial;")

        con.execute("INSTALL httpfs;")
        con.execute("LOAD httpfs;")

        con.execute(
            """
            CREATE OR REPLACE TABLE aoi AS
            SELECT ST_GeomFromGeoJSON(?) AS geom
            """,
            [area_of_interest.geometry.model_dump_json()],
        )

        base_url = f"s3://overturemaps-us-west-2/release/{overture_release}/theme=places/type=place/*"
        # The path is inlined as a SQL string literal, so double any single
        # quote to keep the value from terminating the literal.
        source_literal = base_url.replace("'", "''")

        # Caller-supplied values are bound as query parameters rather than
        # interpolated into the SQL text.
        where_conditions = [
            "ST_Within(ST_GeomFromWKB(geometry), (SELECT geom FROM aoi))"
        ]
        params = []

        if place_type:
            where_conditions.append("categories.primary = ?")
            params.append(place_type)

        if place_name:
            where_conditions.append(
                "jaro_winkler_similarity(LOWER(names.primary), LOWER(?)) >= ?"
            )
            params.extend([place_name, similarity_threshold])

        where_clause = " AND ".join(where_conditions)

        query = f"""
            SELECT
                id,
                ST_AsText(ST_GeomFromWKB(geometry)) as geometry_wkt,
                names.primary as name,
                categories.primary as primary_category,
                confidence,
                websites,
                phones,
                addresses
            FROM read_parquet('{source_literal}', filename=true, hive_partitioning=1)
            WHERE {where_clause}
        """

        result = con.execute(query, params).fetchall()
        columns = [desc[0] for desc in con.description]
    finally:
        # Release the connection even when a query fails.
        con.close()

    # Convert locations to GeoJSON Features
    features = []
    for row in result:
        loc = dict(zip(columns, row))

        # Rows without a geometry cannot be turned into a Feature; skip them.
        geom_wkt = loc.get("geometry_wkt")
        if not geom_wkt:
            continue

        # Parse WKT geometry to GeoJSON
        geom_dict = mapping(wkt.loads(geom_wkt))

        # Every selected column except the raw WKT becomes a property.
        properties = {
            key: value for key, value in loc.items() if key != "geometry_wkt"
        }

        # Pass the GeoJSON type discriminator explicitly.
        features.append(
            Feature(type="Feature", geometry=geom_dict, properties=properties)
        )

    tool_message = f"Found {len(features)} locations matching the criteria"

    # NOTE(review): AgentState declares ``feature_collection`` while this
    # update writes ``features`` -- confirm the key actually reaches the state.
    return Command(
        update={
            "features": features,
            "messages": [ToolMessage(content=tool_message, tool_call_id=tool_call_id)],
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
@tool
def geocode_division(
    query: str,
    level: Optional[str] = None,
    overture_release: str = "2024-11-13.0",
    similarity_threshold: float = 0.6,
    limit: int = 10,
    tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command:
    """
    Geocode a place name using Overture divisions data.

    Parameters
    ----------
    query : str
        Place name to search for (e.g., "San Francisco", "California", "United States")
    level : str, optional
        Division level to filter by. Options:
        - 'country'
        - 'region' (states, provinces)
        - 'county' (counties, districts)
        - 'locality' (cities, towns)
        - 'localadmin' (local administrative areas)
        - 'neighborhood'
    overture_release : str
        Overture Maps release version
    similarity_threshold : float
        Minimum similarity score (0-1) for fuzzy name matching
    limit : int
        Maximum number of results to return

    Returns
    -------
    Command
        Command that updates state with division features
    """
    # The docstring above is left untouched on purpose: this function is
    # wrapped by @tool, so implementation notes are kept in comments instead.

    con = duckdb.connect()
    try:
        con.execute("INSTALL spatial;")
        con.execute("LOAD spatial;")

        con.execute("INSTALL httpfs;")
        con.execute("LOAD httpfs;")

        base_url = f"s3://overturemaps-us-west-2/release/{overture_release}/theme=divisions/type=division/*"
        # The path is inlined as a SQL string literal, so double any single
        # quote to keep the value from terminating the literal.
        source_literal = base_url.replace("'", "''")

        # Caller-supplied values are bound as positional parameters, listed in
        # the order their placeholders appear in the SQL text: the SELECT-list
        # similarity score first, then the WHERE conditions.
        where_conditions = [
            "jaro_winkler_similarity(LOWER(names.primary), LOWER(?)) >= ?"
        ]
        params = [query, query, similarity_threshold]

        if level:
            where_conditions.append("subtype = ?")
            params.append(level)

        where_clause = " AND ".join(where_conditions)

        query_sql = f"""
            SELECT
                id,
                ST_AsText(ST_GeomFromWKB(geometry)) as geometry_wkt,
                names.primary as name,
                names.common as common_names,
                subtype as division_level,
                country,
                region,
                hierarchies,
                population,
                capital,
                wikidata,
                sources,
                jaro_winkler_similarity(LOWER(names.primary), LOWER(?)) as similarity_score
            FROM read_parquet('{source_literal}', filename=true, hive_partitioning=1)
            WHERE {where_clause}
            ORDER BY similarity_score DESC
            LIMIT {int(limit)}
        """

        result = con.execute(query_sql, params).fetchall()
        columns = [desc[0] for desc in con.description]
    finally:
        # Release the connection even when a query fails.
        con.close()

    # Convert divisions to GeoJSON Features
    features = []
    for row in result:
        div = dict(zip(columns, row))

        # Rows without a geometry cannot be turned into a Feature; skip them.
        geom_wkt = div.get("geometry_wkt")
        if not geom_wkt:
            continue

        # Parse WKT geometry to GeoJSON
        geom_dict = mapping(wkt.loads(geom_wkt))

        # Every selected column except the raw WKT becomes a property.
        properties = {
            key: value for key, value in div.items() if key != "geometry_wkt"
        }

        # Pass the GeoJSON type discriminator explicitly.
        features.append(
            Feature(type="Feature", geometry=geom_dict, properties=properties)
        )

    tool_message = f"Found {len(features)} divisions matching '{query}'"

    # NOTE(review): AgentState declares ``feature_collection`` while this
    # update writes ``features`` -- confirm the key actually reaches the state.
    return Command(
        update={
            "features": features,
            "messages": [ToolMessage(content=tool_message, tool_call_id=tool_call_id)],
        },
    )
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from geo_assistant.api.app import app
|
||||||
|
from geo_assistant.agent.graph import create_graph
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def initialized_app():
    """Yield the app with its chatbot set up, removing the chatbot afterwards."""
    # Do by hand what the lifespan hook would do on startup
    graph = await create_graph()
    app.state.chatbot = graph
    yield app
    # Drop the chatbot again if it is still attached
    if hasattr(app.state, "chatbot"):
        delattr(app.state, "chatbot")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_hello_world(initialized_app):
    """Post a greeting to /chat and expect a non-empty NDJSON stream."""
    request_json = {
        "agent_state_input": {
            "messages": [{"content": "Hello, world!", "type": "human"}],
            "feature_collection": None,
        },
        "thread_id": str(uuid4()),
    }
    transport = ASGITransport(app=initialized_app)

    async with AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.post("/chat", json=request_json)

        assert response.status_code == 200
        assert response.headers["content-type"] == "application/x-ndjson; charset=utf-8"

        # Read the streaming response
        body = response.text
        assert body is not None
        assert len(body) > 0
|
||||||
Reference in New Issue
Block a user