mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-15 15:31:02 +02:00
Fix frontend and api (#9)
This commit is contained in:
@@ -25,11 +25,13 @@ The application will automatically load these variables from the `.env` file.
|
|||||||
This project uses pre-commit hooks to ensure code quality. To set up pre-commit:
|
This project uses pre-commit hooks to ensure code quality. To set up pre-commit:
|
||||||
|
|
||||||
1. Install dependencies (including pre-commit):
|
1. Install dependencies (including pre-commit):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv sync
|
uv sync
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install the git hooks:
|
1. Install the git hooks:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run pre-commit install
|
uv run pre-commit install
|
||||||
```
|
```
|
||||||
@@ -37,6 +39,7 @@ uv run pre-commit install
|
|||||||
Pre-commit will now automatically run ruff linting and formatting checks before each commit.
|
Pre-commit will now automatically run ruff linting and formatting checks before each commit.
|
||||||
|
|
||||||
To manually run pre-commit on all files:
|
To manually run pre-commit on all files:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run pre-commit run --all-files
|
uv run pre-commit run --all-files
|
||||||
```
|
```
|
||||||
@@ -44,7 +47,7 @@ uv run pre-commit run --all-files
|
|||||||
## Running the API
|
## Running the API
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uvicorn geo_assistant.api.app:app --reload
|
uv run uvicorn geo_assistant.api.app:app --reload
|
||||||
```
|
```
|
||||||
|
|
||||||
The API will be available at `http://localhost:8000`.
|
The API will be available at `http://localhost:8000`.
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ dependencies = [
|
|||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
"duckdb",
|
"duckdb",
|
||||||
"shapely",
|
"shapely",
|
||||||
|
"watchdog>=6.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ from typing import Optional
|
|||||||
|
|
||||||
|
|
||||||
class GeoAssistantState(BaseAgentState):
|
class GeoAssistantState(BaseAgentState):
|
||||||
place: Optional[FeatureCollection]
|
place: Optional[FeatureCollection] = None
|
||||||
|
|||||||
@@ -84,26 +84,22 @@ async def stream_chat(
|
|||||||
stream_mode="updates",
|
stream_mode="updates",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
async with aclosing(stream):
|
||||||
async with aclosing(stream):
|
async for update in stream:
|
||||||
async for update in stream:
|
if await request.is_disconnected():
|
||||||
if await request.is_disconnected():
|
logger.info("Client disconnected; stopping stream.")
|
||||||
logger.info("Client disconnected; stopping stream.")
|
break
|
||||||
break
|
|
||||||
|
|
||||||
agent = next(iter(update.keys()))
|
agent = next(iter(update.keys()))
|
||||||
payload = update[agent]
|
payload = update[agent]
|
||||||
if "feature_collection" not in payload: # TODO
|
if "place" not in payload: # TODO: why is this needed?
|
||||||
payload["feature_collection"] = None
|
payload["place"] = None
|
||||||
state_payload = GeoAssistantState(**payload)
|
state_payload = GeoAssistantState(**payload)
|
||||||
|
|
||||||
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
|
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
|
||||||
|
|
||||||
line = json.dumps(resp.model_dump()) + "\n"
|
line = json.dumps(resp.model_dump()) + "\n"
|
||||||
yield line.encode("utf-8")
|
yield line.encode("utf-8")
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("stream_chat error: %r", e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/chat")
|
@app.post("/chat")
|
||||||
|
|||||||
@@ -21,75 +21,50 @@ if "chat_history" not in st.session_state:
|
|||||||
st.session_state.chat_history = []
|
st.session_state.chat_history = []
|
||||||
|
|
||||||
|
|
||||||
def send_message(user_message: str, message_container):
|
def stream_chat(user_message: str):
|
||||||
"""Send a message to the API and stream the response."""
|
"""Send a message to the API and stream the response."""
|
||||||
thread_id = st.session_state.thread_id
|
thread_id = st.session_state.thread_id
|
||||||
|
|
||||||
# Prepare request body
|
# Prepare request body
|
||||||
request_body = {
|
request_body = {
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"agent_state": {
|
"agent_state_input": {
|
||||||
"messages": [{"type": "human", "content": user_message}],
|
"messages": [{"type": "human", "content": user_message}],
|
||||||
"features": [],
|
"place": None,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create a placeholder for streaming response
|
with httpx.stream(
|
||||||
response_placeholder = message_container.empty()
|
"POST",
|
||||||
last_messages = []
|
f"{API_BASE_URL}/chat",
|
||||||
|
json=request_body,
|
||||||
|
timeout=60.0,
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
try:
|
for line in response.iter_lines():
|
||||||
with httpx.stream(
|
print("=" * 100)
|
||||||
"POST",
|
print(line)
|
||||||
f"{API_BASE_URL}/chat",
|
print("=" * 100)
|
||||||
json=request_body,
|
|
||||||
timeout=60.0,
|
|
||||||
) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
for line in response.iter_lines():
|
if not line:
|
||||||
if line:
|
continue
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
state = data.get("state", {})
|
|
||||||
messages = state.get("messages", [])
|
|
||||||
|
|
||||||
# Display the latest messages
|
data = json.loads(line)
|
||||||
if messages:
|
print("=" * 100)
|
||||||
display_parts = []
|
print(data)
|
||||||
for msg in messages:
|
print("=" * 100)
|
||||||
msg_type = msg.get("type", "")
|
state = data.get("state", {})
|
||||||
content = msg.get("content", "")
|
messages = state.get("messages", [])
|
||||||
|
|
||||||
if msg_type == "tool":
|
if not messages:
|
||||||
# Display tool messages as JSON code blocks
|
continue
|
||||||
if isinstance(content, (dict, list)):
|
|
||||||
content_str = json.dumps(content, indent=2)
|
|
||||||
else:
|
|
||||||
content_str = str(content)
|
|
||||||
display_parts.append(
|
|
||||||
f"**Tool:**\n```json\n{content_str}\n```"
|
|
||||||
)
|
|
||||||
elif msg_type in ["ai", "assistant"]:
|
|
||||||
# Display AI messages as normal text
|
|
||||||
display_parts.append(f"**AI:** {content}")
|
|
||||||
|
|
||||||
if display_parts:
|
for msg in messages:
|
||||||
full_response = "\n\n".join(display_parts)
|
msg_type = msg.get("type", "")
|
||||||
response_placeholder.markdown(full_response)
|
content = msg.get("content", "")
|
||||||
last_messages = messages
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Return the final messages for history
|
yield msg_type, content
|
||||||
return last_messages
|
|
||||||
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
message_container.error(f"Error connecting to API: {e}")
|
|
||||||
return []
|
|
||||||
except Exception as e:
|
|
||||||
message_container.error(f"Error: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
# Main UI
|
# Main UI
|
||||||
@@ -112,30 +87,14 @@ for item in st.session_state.chat_history:
|
|||||||
|
|
||||||
# Chat input
|
# Chat input
|
||||||
if prompt := st.chat_input("Type your message..."):
|
if prompt := st.chat_input("Type your message..."):
|
||||||
# Add user message to history
|
|
||||||
st.session_state.chat_history.append({"role": "user", "content": prompt})
|
st.session_state.chat_history.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
# Send message and get response
|
for msg_type, content in stream_chat(prompt):
|
||||||
with st.chat_message("assistant"):
|
if msg_type == "tool":
|
||||||
final_messages = send_message(prompt, st.container())
|
st.session_state.chat_history.append({"role": "tool", "content": content})
|
||||||
|
elif msg_type in ["ai", "assistant"]:
|
||||||
# Add response to history
|
st.session_state.chat_history.append(
|
||||||
if final_messages:
|
{"role": "assistant", "content": content}
|
||||||
for msg in final_messages:
|
)
|
||||||
msg_type = msg.get("type", "")
|
|
||||||
content = msg.get("content", "")
|
|
||||||
|
|
||||||
if msg_type == "tool":
|
|
||||||
if isinstance(content, (dict, list)):
|
|
||||||
content_str = json.dumps(content, indent=2)
|
|
||||||
else:
|
|
||||||
content_str = str(content)
|
|
||||||
st.session_state.chat_history.append(
|
|
||||||
{"role": "assistant", "content": content_str, "is_tool": True}
|
|
||||||
)
|
|
||||||
elif msg_type in ["ai", "assistant"]:
|
|
||||||
st.session_state.chat_history.append(
|
|
||||||
{"role": "assistant", "content": content, "is_tool": False}
|
|
||||||
)
|
|
||||||
|
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ dependencies = [
|
|||||||
{ name = "shapely" },
|
{ name = "shapely" },
|
||||||
{ name = "streamlit" },
|
{ name = "streamlit" },
|
||||||
{ name = "uvicorn", extra = ["standard"] },
|
{ name = "uvicorn", extra = ["standard"] },
|
||||||
|
{ name = "watchdog" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dev-dependencies]
|
[package.dev-dependencies]
|
||||||
@@ -249,6 +250,7 @@ requires-dist = [
|
|||||||
{ name = "shapely" },
|
{ name = "shapely" },
|
||||||
{ name = "streamlit" },
|
{ name = "streamlit" },
|
||||||
{ name = "uvicorn", extras = ["standard"] },
|
{ name = "uvicorn", extras = ["standard"] },
|
||||||
|
{ name = "watchdog", specifier = ">=6.0.0" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.metadata.requires-dev]
|
[package.metadata.requires-dev]
|
||||||
@@ -1541,6 +1543,9 @@ version = "6.0.0"
|
|||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220, upload-time = "2024-11-01T14:07:13.037Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220, upload-time = "2024-11-01T14:07:13.037Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/68/98/b0345cabdce2041a01293ba483333582891a3bd5769b08eceb0d406056ef/watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c", size = 96480, upload-time = "2024-11-01T14:06:42.952Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/85/83/cdf13902c626b28eedef7ec4f10745c52aad8a8fe7eb04ed7b1f111ca20e/watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134", size = 88451, upload-time = "2024-11-01T14:06:45.084Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fe/c4/225c87bae08c8b9ec99030cd48ae9c4eca050a59bf5c2255853e18c87b50/watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b", size = 89057, upload-time = "2024-11-01T14:06:47.324Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079, upload-time = "2024-11-01T14:06:59.472Z" },
|
{ url = "https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079, upload-time = "2024-11-01T14:06:59.472Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078, upload-time = "2024-11-01T14:07:01.431Z" },
|
{ url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078, upload-time = "2024-11-01T14:07:01.431Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076, upload-time = "2024-11-01T14:07:02.568Z" },
|
{ url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076, upload-time = "2024-11-01T14:07:02.568Z" },
|
||||||
|
|||||||
Reference in New Issue
Block a user