mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Initial commit
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
name: Ruff
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/ruff-action@v3
|
||||
with:
|
||||
args: check
|
||||
- uses: astral-sh/ruff-action@v3
|
||||
with:
|
||||
args: format --check
|
||||
@@ -205,3 +205,6 @@ cython_debug/
|
||||
marimo/_static/
|
||||
marimo/_lsp/
|
||||
__marimo__/
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.9
|
||||
hooks:
|
||||
# Run the linter
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
# Run the formatter
|
||||
- id: ruff-format
|
||||
@@ -0,0 +1,58 @@
|
||||
# Geo Assistant
|
||||
|
||||
A geographic assistant that helps answer questions and perform tasks related to locations and geographic data.
|
||||
|
||||
## Environment Setup
|
||||
|
||||
The project uses environment variables for configuration. Copy `.env.example` to `.env` and customize as needed:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
Edit `.env` to set your configuration:
|
||||
|
||||
- `OLLAMA_MODEL`: Model name (default: `llama3.2`)
|
||||
- `OLLAMA_BASE_URL`: Ollama server URL (default: `http://localhost:11434`)
|
||||
- `API_BASE_URL`: API base URL for the frontend (default: `http://localhost:8000`)
|
||||
|
||||
The application will automatically load these variables from the `.env` file.
|
||||
|
||||
## Development Setup
|
||||
|
||||
### Pre-commit Hooks
|
||||
|
||||
This project uses pre-commit hooks to ensure code quality. To set up pre-commit:
|
||||
|
||||
1. Install dependencies (including pre-commit):
|
||||
```bash
|
||||
uv sync
|
||||
```
|
||||
|
||||
2. Install the git hooks:
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
Pre-commit will now automatically run ruff linting and formatting checks before each commit.
|
||||
|
||||
To manually run pre-commit on all files:
|
||||
```bash
|
||||
uv run pre-commit run --all-files
|
||||
```
|
||||
|
||||
## Running the API
|
||||
|
||||
```bash
|
||||
uvicorn geo_assistant.api.app:app --reload
|
||||
```
|
||||
|
||||
The API will be available at `http://localhost:8000`.
|
||||
|
||||
## Running the Frontend
|
||||
|
||||
```bash
|
||||
streamlit run src/geo_assistant/frontend/app.py
|
||||
```
|
||||
|
||||
The frontend will be available at `http://localhost:8501`.
|
||||
@@ -0,0 +1,35 @@
|
||||
[project]
|
||||
name = "geo-assistant"
|
||||
version = "0.0.1"
|
||||
description = "Geo Assistant"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"langgraph",
|
||||
"fastapi",
|
||||
"pydantic",
|
||||
"geojson_pydantic",
|
||||
"streamlit",
|
||||
"httpx",
|
||||
"uvicorn[standard]",
|
||||
"langchain-ollama",
|
||||
"langchain",
|
||||
"python-dotenv",
|
||||
"duckdb",
|
||||
"shapely",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"ruff",
|
||||
"pytest",
|
||||
"pytest-asyncio",
|
||||
"pre-commit",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/geo_assistant"]
|
||||
@@ -0,0 +1,29 @@
|
||||
import datetime
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langchain.agents import create_agent
|
||||
from geo_assistant.agent.state import AgentState
|
||||
from geo_assistant.agent.llms import llm
|
||||
from geo_assistant.tools.overture import get_overture_locations
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that can answer questions and help with tasks.
|
||||
|
||||
You have location and division tools available to you. Only use this data if the user asks for it.
|
||||
|
||||
The current date and time is {now}.
|
||||
"""
|
||||
|
||||
|
||||
async def create_graph():
|
||||
checkpointer = InMemorySaver()
|
||||
graph = create_agent(
|
||||
model=llm,
|
||||
tools=[get_overture_locations], # [get_overture_locations, geocode_division],
|
||||
system_prompt=SYSTEM_PROMPT.format(
|
||||
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
),
|
||||
state_schema=AgentState,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
return graph
|
||||
@@ -0,0 +1,16 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Get model name from environment variable, default to llama3.2
|
||||
MODEL_NAME = os.environ.get("OLLAMA_MODEL", "llama3.2")
|
||||
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
|
||||
|
||||
llm = ChatOllama(
|
||||
model=MODEL_NAME,
|
||||
base_url=OLLAMA_BASE_URL,
|
||||
)
|
||||
@@ -0,0 +1,10 @@
|
||||
from langchain.agents import AgentState as BaseAgentState
|
||||
from geojson_pydantic import FeatureCollection
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class AgentState(BaseAgentState):
|
||||
feature_collection: Optional[FeatureCollection] = Field(
|
||||
default=None, description="FeatureCollection to be used for the analysis"
|
||||
)
|
||||
@@ -0,0 +1,125 @@
|
||||
import json
|
||||
from contextlib import aclosing, asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
import logging
|
||||
from pydantic import UUID4
|
||||
|
||||
from geo_assistant.agent.graph import create_graph
|
||||
from geo_assistant.agent.state import AgentState
|
||||
from geo_assistant.api.schemas.chat import ChatRequestBody, ChatResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Whitelist state fields that can be set by the user.
|
||||
# Note that these attrs need to be pydantic Fields and
|
||||
# need a description in the AgentState model.
|
||||
UI_SET_FIELDS_WHITELIST = ["feature_collection", "messages"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
app.state.chatbot = await create_graph()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Geo Assistant", lifespan=lifespan)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all HTTP methods (GET, POST, PUT, DELETE, etc.)
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
|
||||
async def stream_chat(
|
||||
ui_state_update: AgentState,
|
||||
thread_id: UUID4,
|
||||
chatbot: Any,
|
||||
request: Request,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
config: Dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": str(thread_id),
|
||||
}
|
||||
}
|
||||
|
||||
state_updates = {}
|
||||
|
||||
vars_to_update = {
|
||||
key: val
|
||||
for key, val in ui_state_update.items()
|
||||
if val and key in UI_SET_FIELDS_WHITELIST
|
||||
}
|
||||
logger.debug(f"State variables to update: {vars_to_update}")
|
||||
|
||||
ui_messages = []
|
||||
for key in vars_to_update.keys():
|
||||
if hasattr(AgentState, key):
|
||||
field_info = getattr(AgentState, key)
|
||||
description = field_info.description if field_info else f"Field {key}"
|
||||
if description:
|
||||
ui_messages.append(
|
||||
{
|
||||
"content": f"Manually selected data for field {key}: {description}",
|
||||
"type": "human",
|
||||
}
|
||||
)
|
||||
|
||||
# Add UI messages to the existing messages if they exist
|
||||
existing_messages = vars_to_update.get("messages", [])
|
||||
vars_to_update["messages"] = ui_messages + existing_messages
|
||||
|
||||
state_updates.update(vars_to_update)
|
||||
|
||||
stream = chatbot.astream(
|
||||
input=state_updates,
|
||||
config=config,
|
||||
stream_mode="updates",
|
||||
)
|
||||
|
||||
try:
|
||||
async with aclosing(stream):
|
||||
async for update in stream:
|
||||
if await request.is_disconnected():
|
||||
logger.info("Client disconnected; stopping stream.")
|
||||
break
|
||||
|
||||
agent = next(iter(update.keys()))
|
||||
payload = update[agent]
|
||||
if "feature_collection" not in payload: # TODO
|
||||
payload["feature_collection"] = None
|
||||
state_payload = AgentState(**payload)
|
||||
|
||||
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
|
||||
|
||||
line = json.dumps(resp.model_dump()) + "\n"
|
||||
yield line.encode("utf-8")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("stream_chat error: %r", e)
|
||||
|
||||
|
||||
@app.post("/chat")
|
||||
async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse:
|
||||
generator = stream_chat(
|
||||
ui_state_update=request.agent_state_input,
|
||||
thread_id=request.thread_id,
|
||||
chatbot=http_request.app.state.chatbot,
|
||||
request=http_request,
|
||||
)
|
||||
return StreamingResponse(
|
||||
generator,
|
||||
media_type="application/x-ndjson; charset=utf-8",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
# If you run behind nginx, this prevents buffering of the stream:
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel
|
||||
from geo_assistant.agent.state import AgentState
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
|
||||
agent_state_input: AgentState
|
||||
thread_id: str
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
thread_id: str
|
||||
state: AgentState
|
||||
@@ -0,0 +1,141 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
import streamlit as st
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# API configuration
|
||||
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")
|
||||
|
||||
st.set_page_config(page_title="Geo Assistant", page_icon="💬")
|
||||
|
||||
# Initialize session state
|
||||
if "thread_id" not in st.session_state:
|
||||
st.session_state.thread_id = str(uuid.uuid4())
|
||||
if "chat_history" not in st.session_state:
|
||||
st.session_state.chat_history = []
|
||||
|
||||
|
||||
def send_message(user_message: str, message_container):
|
||||
"""Send a message to the API and stream the response."""
|
||||
thread_id = st.session_state.thread_id
|
||||
|
||||
# Prepare request body
|
||||
request_body = {
|
||||
"thread_id": thread_id,
|
||||
"agent_state": {
|
||||
"messages": [{"type": "human", "content": user_message}],
|
||||
"features": [],
|
||||
},
|
||||
}
|
||||
|
||||
# Create a placeholder for streaming response
|
||||
response_placeholder = message_container.empty()
|
||||
last_messages = []
|
||||
|
||||
try:
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{API_BASE_URL}/chat",
|
||||
json=request_body,
|
||||
timeout=60.0,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
state = data.get("state", {})
|
||||
messages = state.get("messages", [])
|
||||
|
||||
# Display the latest messages
|
||||
if messages:
|
||||
display_parts = []
|
||||
for msg in messages:
|
||||
msg_type = msg.get("type", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if msg_type == "tool":
|
||||
# Display tool messages as JSON code blocks
|
||||
if isinstance(content, (dict, list)):
|
||||
content_str = json.dumps(content, indent=2)
|
||||
else:
|
||||
content_str = str(content)
|
||||
display_parts.append(
|
||||
f"**Tool:**\n```json\n{content_str}\n```"
|
||||
)
|
||||
elif msg_type in ["ai", "assistant"]:
|
||||
# Display AI messages as normal text
|
||||
display_parts.append(f"**AI:** {content}")
|
||||
|
||||
if display_parts:
|
||||
full_response = "\n\n".join(display_parts)
|
||||
response_placeholder.markdown(full_response)
|
||||
last_messages = messages
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Return the final messages for history
|
||||
return last_messages
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
message_container.error(f"Error connecting to API: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
message_container.error(f"Error: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# Main UI
|
||||
st.title("Geo Assistant")
|
||||
|
||||
# Display chat history
|
||||
for item in st.session_state.chat_history:
|
||||
role = item["role"]
|
||||
content = item["content"]
|
||||
|
||||
with st.chat_message(role):
|
||||
if role == "assistant":
|
||||
# For assistant messages, check if it's a tool message
|
||||
if item.get("is_tool"):
|
||||
st.code(content, language="json")
|
||||
else:
|
||||
st.markdown(content)
|
||||
else:
|
||||
st.markdown(content)
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Type your message..."):
|
||||
# Add user message to history
|
||||
st.session_state.chat_history.append({"role": "user", "content": prompt})
|
||||
|
||||
# Send message and get response
|
||||
with st.chat_message("assistant"):
|
||||
final_messages = send_message(prompt, st.container())
|
||||
|
||||
# Add response to history
|
||||
if final_messages:
|
||||
for msg in final_messages:
|
||||
msg_type = msg.get("type", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if msg_type == "tool":
|
||||
if isinstance(content, (dict, list)):
|
||||
content_str = json.dumps(content, indent=2)
|
||||
else:
|
||||
content_str = str(content)
|
||||
st.session_state.chat_history.append(
|
||||
{"role": "assistant", "content": content_str, "is_tool": True}
|
||||
)
|
||||
elif msg_type in ["ai", "assistant"]:
|
||||
st.session_state.chat_history.append(
|
||||
{"role": "assistant", "content": content, "is_tool": False}
|
||||
)
|
||||
|
||||
st.rerun()
|
||||
@@ -0,0 +1,248 @@
|
||||
from typing import Optional, Annotated
|
||||
import duckdb
|
||||
from geojson_pydantic import Feature
|
||||
from shapely import wkt
|
||||
from shapely.geometry import mapping
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
|
||||
|
||||
@tool
|
||||
def get_overture_locations(
|
||||
area_of_interest: Feature,
|
||||
place_name: Optional[str] = None,
|
||||
place_type: Optional[str] = None,
|
||||
overture_release: str = "2024-11-13.0",
|
||||
similarity_threshold: float = 0.6,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||
) -> Command:
|
||||
"""
|
||||
Get locations from Overture Maps.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
area_of_interest : Feature
|
||||
Area of interest to search for locations in
|
||||
place_name : str, optional
|
||||
Name of the place to search for
|
||||
place_type : str, optional
|
||||
Type of the place to search for
|
||||
overture_release : str
|
||||
Overture Maps release version
|
||||
similarity_threshold : float
|
||||
Minimum similarity score (0-1) for fuzzy name matching
|
||||
tool_call_id : str
|
||||
Tool call ID
|
||||
|
||||
Returns
|
||||
-------
|
||||
Command
|
||||
Command that updates state with location features
|
||||
"""
|
||||
|
||||
con = duckdb.connect()
|
||||
|
||||
con.execute("INSTALL spatial;")
|
||||
con.execute("LOAD spatial;")
|
||||
|
||||
con.execute("INSTALL httpfs;")
|
||||
con.execute("LOAD httpfs;")
|
||||
|
||||
con.execute(
|
||||
"""
|
||||
CREATE OR REPLACE TABLE aoi AS
|
||||
SELECT ST_GeomFromGeoJSON(?) AS geom
|
||||
""",
|
||||
[area_of_interest.geometry.model_dump_json()],
|
||||
)
|
||||
|
||||
base_url = f"s3://overturemaps-us-west-2/release/{overture_release}/theme=places/type=place/*"
|
||||
|
||||
where_conditions = ["ST_Within(ST_GeomFromWKB(geometry), (SELECT geom FROM aoi))"]
|
||||
|
||||
if place_type:
|
||||
where_conditions.append(f"categories.primary = '{place_type}'")
|
||||
|
||||
if place_name:
|
||||
where_conditions.append(
|
||||
f"jaro_winkler_similarity(LOWER(names.primary), LOWER('{place_name}')) >= {similarity_threshold}"
|
||||
)
|
||||
|
||||
where_clause = " AND ".join(where_conditions)
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
id,
|
||||
ST_AsText(ST_GeomFromWKB(geometry)) as geometry_wkt,
|
||||
names.primary as name,
|
||||
categories.primary as primary_category,
|
||||
confidence,
|
||||
websites,
|
||||
phones,
|
||||
addresses
|
||||
FROM read_parquet('{base_url}', filename=true, hive_partitioning=1)
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
|
||||
result = con.execute(query).fetchall()
|
||||
columns = [desc[0] for desc in con.description]
|
||||
|
||||
locations = [dict(zip(columns, row)) for row in result]
|
||||
|
||||
# Convert locations to GeoJSON Features
|
||||
features = []
|
||||
for loc in locations:
|
||||
# Parse WKT geometry to GeoJSON
|
||||
geom_wkt = loc.get("geometry_wkt")
|
||||
if geom_wkt:
|
||||
shapely_geom = wkt.loads(geom_wkt)
|
||||
geom_dict = mapping(shapely_geom)
|
||||
|
||||
# Create properties from location data
|
||||
properties = {
|
||||
"id": loc.get("id"),
|
||||
"name": loc.get("name"),
|
||||
"primary_category": loc.get("primary_category"),
|
||||
"confidence": loc.get("confidence"),
|
||||
"websites": loc.get("websites"),
|
||||
"phones": loc.get("phones"),
|
||||
"addresses": loc.get("addresses"),
|
||||
}
|
||||
|
||||
feature = Feature(geometry=geom_dict, properties=properties)
|
||||
features.append(feature)
|
||||
|
||||
con.close()
|
||||
|
||||
tool_message = f"Found {len(features)} locations matching the criteria"
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"features": features,
|
||||
"messages": [ToolMessage(content=tool_message, tool_call_id=tool_call_id)],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
def geocode_division(
|
||||
query: str,
|
||||
level: Optional[str] = None,
|
||||
overture_release: str = "2024-11-13.0",
|
||||
similarity_threshold: float = 0.6,
|
||||
limit: int = 10,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId] = "",
|
||||
) -> Command:
|
||||
"""
|
||||
Geocode a place name using Overture divisions data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : str
|
||||
Place name to search for (e.g., "San Francisco", "California", "United States")
|
||||
level : str, optional
|
||||
Division level to filter by. Options:
|
||||
- 'country'
|
||||
- 'region' (states, provinces)
|
||||
- 'county' (counties, districts)
|
||||
- 'locality' (cities, towns)
|
||||
- 'localadmin' (local administrative areas)
|
||||
- 'neighborhood'
|
||||
overture_release : str
|
||||
Overture Maps release version
|
||||
similarity_threshold : float
|
||||
Minimum similarity score (0-1) for fuzzy name matching
|
||||
limit : int
|
||||
Maximum number of results to return
|
||||
|
||||
Returns
|
||||
-------
|
||||
Command
|
||||
Command that updates state with division features
|
||||
"""
|
||||
|
||||
con = duckdb.connect()
|
||||
|
||||
con.execute("INSTALL spatial;")
|
||||
con.execute("LOAD spatial;")
|
||||
|
||||
con.execute("INSTALL httpfs;")
|
||||
con.execute("LOAD httpfs;")
|
||||
|
||||
base_url = f"s3://overturemaps-us-west-2/release/{overture_release}/theme=divisions/type=division/*"
|
||||
|
||||
where_conditions = [
|
||||
f"jaro_winkler_similarity(LOWER(names.primary), LOWER('{query}')) >= {similarity_threshold}"
|
||||
]
|
||||
|
||||
if level:
|
||||
where_conditions.append(f"subtype = '{level}'")
|
||||
|
||||
where_clause = " AND ".join(where_conditions)
|
||||
|
||||
query_sql = f"""
|
||||
SELECT
|
||||
id,
|
||||
ST_AsText(ST_GeomFromWKB(geometry)) as geometry_wkt,
|
||||
names.primary as name,
|
||||
names.common as common_names,
|
||||
subtype as division_level,
|
||||
country,
|
||||
region,
|
||||
hierarchies,
|
||||
population,
|
||||
capital,
|
||||
wikidata,
|
||||
sources,
|
||||
jaro_winkler_similarity(LOWER(names.primary), LOWER('{query}')) as similarity_score
|
||||
FROM read_parquet('{base_url}', filename=true, hive_partitioning=1)
|
||||
WHERE {where_clause}
|
||||
ORDER BY similarity_score DESC
|
||||
LIMIT {limit}
|
||||
"""
|
||||
|
||||
result = con.execute(query_sql).fetchall()
|
||||
columns = [desc[0] for desc in con.description]
|
||||
|
||||
divisions = [dict(zip(columns, row)) for row in result]
|
||||
|
||||
# Convert divisions to GeoJSON Features
|
||||
features = []
|
||||
for div in divisions:
|
||||
# Parse WKT geometry to GeoJSON
|
||||
geom_wkt = div.get("geometry_wkt")
|
||||
if geom_wkt:
|
||||
shapely_geom = wkt.loads(geom_wkt)
|
||||
geom_dict = mapping(shapely_geom)
|
||||
|
||||
# Create properties from division data
|
||||
properties = {
|
||||
"id": div.get("id"),
|
||||
"name": div.get("name"),
|
||||
"common_names": div.get("common_names"),
|
||||
"division_level": div.get("division_level"),
|
||||
"country": div.get("country"),
|
||||
"region": div.get("region"),
|
||||
"hierarchies": div.get("hierarchies"),
|
||||
"population": div.get("population"),
|
||||
"capital": div.get("capital"),
|
||||
"wikidata": div.get("wikidata"),
|
||||
"sources": div.get("sources"),
|
||||
"similarity_score": div.get("similarity_score"),
|
||||
}
|
||||
|
||||
feature = Feature(geometry=geom_dict, properties=properties)
|
||||
features.append(feature)
|
||||
|
||||
con.close()
|
||||
|
||||
tool_message = f"Found {len(features)} divisions matching '{query}'"
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"features": features,
|
||||
"messages": [ToolMessage(content=tool_message, tool_call_id=tool_call_id)],
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from uuid import uuid4
|
||||
|
||||
from geo_assistant.api.app import app
|
||||
from geo_assistant.agent.graph import create_graph
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def initialized_app():
|
||||
"""Initialize the app's chatbot before testing"""
|
||||
# Manually initialize the chatbot as the lifespan would
|
||||
app.state.chatbot = await create_graph()
|
||||
yield app
|
||||
# Cleanup if needed
|
||||
if hasattr(app.state, "chatbot"):
|
||||
del app.state.chatbot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hello_world(initialized_app):
|
||||
"""Hello world test for the API"""
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=initialized_app), base_url="http://test"
|
||||
) as client:
|
||||
thread_id = uuid4()
|
||||
response = await client.post(
|
||||
"/chat",
|
||||
json={
|
||||
"agent_state_input": {
|
||||
"messages": [{"content": "Hello, world!", "type": "human"}],
|
||||
"feature_collection": None,
|
||||
},
|
||||
"thread_id": str(thread_id),
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/x-ndjson; charset=utf-8"
|
||||
|
||||
# Read the streaming response
|
||||
content = response.text
|
||||
assert content is not None
|
||||
assert len(content) > 0
|
||||
Reference in New Issue
Block a user