diff --git a/README.md b/README.md index 361f338..663c8fe 100644 --- a/README.md +++ b/README.md @@ -25,11 +25,13 @@ The application will automatically load these variables from the `.env` file. This project uses pre-commit hooks to ensure code quality. To set up pre-commit: 1. Install dependencies (including pre-commit): + ```bash uv sync ``` -2. Install the git hooks: +1. Install the git hooks: + ```bash uv run pre-commit install ``` @@ -37,6 +39,7 @@ uv run pre-commit install Pre-commit will now automatically run ruff linting and formatting checks before each commit. To manually run pre-commit on all files: + ```bash uv run pre-commit run --all-files ``` @@ -44,7 +47,7 @@ uv run pre-commit run --all-files ## Running the API ```bash -uvicorn geo_assistant.api.app:app --reload +uv run uvicorn geo_assistant.api.app:app --reload ``` The API will be available at `http://localhost:8000`. diff --git a/pyproject.toml b/pyproject.toml index 1fa3bef..4897dfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "python-dotenv", "duckdb", "shapely", + "watchdog>=6.0.0", ] [dependency-groups] diff --git a/src/geo_assistant/agent/state.py b/src/geo_assistant/agent/state.py index 8431f26..521fadc 100644 --- a/src/geo_assistant/agent/state.py +++ b/src/geo_assistant/agent/state.py @@ -4,4 +4,4 @@ from typing import Optional class GeoAssistantState(BaseAgentState): - place: Optional[FeatureCollection] + place: Optional[FeatureCollection] = None diff --git a/src/geo_assistant/api/app.py b/src/geo_assistant/api/app.py index 4b92892..0495886 100644 --- a/src/geo_assistant/api/app.py +++ b/src/geo_assistant/api/app.py @@ -84,26 +84,22 @@ async def stream_chat( stream_mode="updates", ) - try: - async with aclosing(stream): - async for update in stream: - if await request.is_disconnected(): - logger.info("Client disconnected; stopping stream.") - break + async with aclosing(stream): + async for update in stream: + if await request.is_disconnected(): + logger.info("Client disconnected; stopping stream.") + break - agent = next(iter(update.keys())) - payload = update[agent] - if "feature_collection" not in payload: # TODO - payload["feature_collection"] = None - state_payload = GeoAssistantState(**payload) + agent = next(iter(update.keys())) + payload = update[agent] + if "place" not in payload: # TODO: why is this needed? + payload["place"] = None + state_payload = GeoAssistantState(**payload) - resp = ChatResponse(thread_id=str(thread_id), state=state_payload) + resp = ChatResponse(thread_id=str(thread_id), state=state_payload) - line = json.dumps(resp.model_dump()) + "\n" - yield line.encode("utf-8") - - except Exception as e: - logger.warning("stream_chat error: %r", e) + line = json.dumps(resp.model_dump()) + "\n" + yield line.encode("utf-8") @app.post("/chat") diff --git a/src/geo_assistant/frontend/app.py b/src/geo_assistant/frontend/app.py index 828c81a..2ab5c29 100644 --- a/src/geo_assistant/frontend/app.py +++ b/src/geo_assistant/frontend/app.py @@ -21,75 +21,50 @@ if "chat_history" not in st.session_state: st.session_state.chat_history = [] -def send_message(user_message: str, message_container): +def stream_chat(user_message: str): """Send a message to the API and stream the response.""" thread_id = st.session_state.thread_id # Prepare request body request_body = { "thread_id": thread_id, - "agent_state": { + "agent_state_input": { "messages": [{"type": "human", "content": user_message}], - "features": [], + "place": None, }, } - # Create a placeholder for streaming response - response_placeholder = message_container.empty() - last_messages = [] + with httpx.stream( + "POST", + f"{API_BASE_URL}/chat", + json=request_body, + timeout=60.0, + ) as response: + response.raise_for_status() - try: - with httpx.stream( - "POST", - f"{API_BASE_URL}/chat", - json=request_body, - timeout=60.0, - ) as response: - response.raise_for_status() + for line in response.iter_lines(): + print("=" * 100) + print(line) + print("=" * 100) - for line in response.iter_lines(): - if line: - try: - data = json.loads(line) - state = data.get("state", {}) - messages = state.get("messages", []) + if not line: + continue - # Display the latest messages - if messages: - display_parts = [] - for msg in messages: - msg_type = msg.get("type", "") - content = msg.get("content", "") + data = json.loads(line) + print("=" * 100) + print(data) + print("=" * 100) + state = data.get("state", {}) + messages = state.get("messages", []) - if msg_type == "tool": - # Display tool messages as JSON code blocks - if isinstance(content, (dict, list)): - content_str = json.dumps(content, indent=2) - else: - content_str = str(content) - display_parts.append( - f"**Tool:**\n```json\n{content_str}\n```" - ) - elif msg_type in ["ai", "assistant"]: - # Display AI messages as normal text - display_parts.append(f"**AI:** {content}") + if not messages: + continue - if display_parts: - full_response = "\n\n".join(display_parts) - response_placeholder.markdown(full_response) - last_messages = messages - except json.JSONDecodeError: - continue + for msg in messages: + msg_type = msg.get("type", "") + content = msg.get("content", "") - # Return the final messages for history - return last_messages - - except httpx.HTTPError as e: - message_container.error(f"Error connecting to API: {e}") - return [] - except Exception as e: - message_container.error(f"Error: {e}") - return [] + yield msg_type, content # Main UI @@ -112,30 +87,14 @@ for item in st.session_state.chat_history: # Chat input if prompt := st.chat_input("Type your message..."): - # Add user message to history st.session_state.chat_history.append({"role": "user", "content": prompt}) - # Send message and get response - with st.chat_message("assistant"): - final_messages = send_message(prompt, st.container()) - - # Add response to history - if final_messages: - for msg in final_messages: - msg_type = msg.get("type", "") - content = msg.get("content", "") - - if msg_type == "tool": - if isinstance(content, (dict, list)): - content_str = json.dumps(content, indent=2) - else: - content_str = str(content) - st.session_state.chat_history.append( - {"role": "assistant", "content": content_str, "is_tool": True} - ) - elif msg_type in ["ai", "assistant"]: - st.session_state.chat_history.append( - {"role": "assistant", "content": content, "is_tool": False} - ) + for msg_type, content in stream_chat(prompt): + if msg_type == "tool": + st.session_state.chat_history.append({"role": "tool", "content": content}) + elif msg_type in ["ai", "assistant"]: + st.session_state.chat_history.append( + {"role": "assistant", "content": content} + ) st.rerun() diff --git a/uv.lock b/uv.lock index 2aee836..54a3075 100644 --- a/uv.lock +++ b/uv.lock @@ -225,6 +225,7 @@ dependencies = [ { name = "shapely" }, { name = "streamlit" }, { name = "uvicorn", extra = ["standard"] }, + { name = "watchdog" }, ] [package.dev-dependencies] @@ -249,6 +250,7 @@ requires-dist = [ { name = "shapely" }, { name = "streamlit" }, { name = "uvicorn", extras = ["standard"] }, + { name = "watchdog", specifier = ">=6.0.0" }, ] [package.metadata.requires-dev] @@ -1541,6 +1543,9 @@ version = "6.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220, upload-time = "2024-11-01T14:07:13.037Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/68/98/b0345cabdce2041a01293ba483333582891a3bd5769b08eceb0d406056ef/watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c", size = 96480, upload-time = "2024-11-01T14:06:42.952Z" }, + { url = "https://files.pythonhosted.org/packages/85/83/cdf13902c626b28eedef7ec4f10745c52aad8a8fe7eb04ed7b1f111ca20e/watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134", size = 88451, upload-time = "2024-11-01T14:06:45.084Z" }, + { url = "https://files.pythonhosted.org/packages/fe/c4/225c87bae08c8b9ec99030cd48ae9c4eca050a59bf5c2255853e18c87b50/watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b", size = 89057, upload-time = "2024-11-01T14:06:47.324Z" }, { url = "https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079, upload-time = "2024-11-01T14:06:59.472Z" }, { url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078, upload-time = "2024-11-01T14:07:01.431Z" }, { url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076, upload-time = "2024-11-01T14:07:02.568Z" },