mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Fix frontend and api (#9)
This commit is contained in:
@@ -4,4 +4,4 @@ from typing import Optional
|
||||
|
||||
|
||||
class GeoAssistantState(BaseAgentState):
|
||||
place: Optional[FeatureCollection]
|
||||
place: Optional[FeatureCollection] = None
|
||||
|
||||
@@ -84,26 +84,22 @@ async def stream_chat(
|
||||
stream_mode="updates",
|
||||
)
|
||||
|
||||
try:
|
||||
async with aclosing(stream):
|
||||
async for update in stream:
|
||||
if await request.is_disconnected():
|
||||
logger.info("Client disconnected; stopping stream.")
|
||||
break
|
||||
async with aclosing(stream):
|
||||
async for update in stream:
|
||||
if await request.is_disconnected():
|
||||
logger.info("Client disconnected; stopping stream.")
|
||||
break
|
||||
|
||||
agent = next(iter(update.keys()))
|
||||
payload = update[agent]
|
||||
if "feature_collection" not in payload: # TODO
|
||||
payload["feature_collection"] = None
|
||||
state_payload = GeoAssistantState(**payload)
|
||||
agent = next(iter(update.keys()))
|
||||
payload = update[agent]
|
||||
if "place" not in payload: # TODO: why is this needed?
|
||||
payload["place"] = None
|
||||
state_payload = GeoAssistantState(**payload)
|
||||
|
||||
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
|
||||
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
|
||||
|
||||
line = json.dumps(resp.model_dump()) + "\n"
|
||||
yield line.encode("utf-8")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("stream_chat error: %r", e)
|
||||
line = json.dumps(resp.model_dump()) + "\n"
|
||||
yield line.encode("utf-8")
|
||||
|
||||
|
||||
@app.post("/chat")
|
||||
|
||||
@@ -21,75 +21,50 @@ if "chat_history" not in st.session_state:
|
||||
st.session_state.chat_history = []
|
||||
|
||||
|
||||
def send_message(user_message: str, message_container):
|
||||
def stream_chat(user_message: str):
|
||||
"""Send a message to the API and stream the response."""
|
||||
thread_id = st.session_state.thread_id
|
||||
|
||||
# Prepare request body
|
||||
request_body = {
|
||||
"thread_id": thread_id,
|
||||
"agent_state": {
|
||||
"agent_state_input": {
|
||||
"messages": [{"type": "human", "content": user_message}],
|
||||
"features": [],
|
||||
"place": None,
|
||||
},
|
||||
}
|
||||
|
||||
# Create a placeholder for streaming response
|
||||
response_placeholder = message_container.empty()
|
||||
last_messages = []
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{API_BASE_URL}/chat",
|
||||
json=request_body,
|
||||
timeout=60.0,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{API_BASE_URL}/chat",
|
||||
json=request_body,
|
||||
timeout=60.0,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
for line in response.iter_lines():
|
||||
print("=" * 100)
|
||||
print(line)
|
||||
print("=" * 100)
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
state = data.get("state", {})
|
||||
messages = state.get("messages", [])
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Display the latest messages
|
||||
if messages:
|
||||
display_parts = []
|
||||
for msg in messages:
|
||||
msg_type = msg.get("type", "")
|
||||
content = msg.get("content", "")
|
||||
data = json.loads(line)
|
||||
print("=" * 100)
|
||||
print(data)
|
||||
print("=" * 100)
|
||||
state = data.get("state", {})
|
||||
messages = state.get("messages", [])
|
||||
|
||||
if msg_type == "tool":
|
||||
# Display tool messages as JSON code blocks
|
||||
if isinstance(content, (dict, list)):
|
||||
content_str = json.dumps(content, indent=2)
|
||||
else:
|
||||
content_str = str(content)
|
||||
display_parts.append(
|
||||
f"**Tool:**\n```json\n{content_str}\n```"
|
||||
)
|
||||
elif msg_type in ["ai", "assistant"]:
|
||||
# Display AI messages as normal text
|
||||
display_parts.append(f"**AI:** {content}")
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
if display_parts:
|
||||
full_response = "\n\n".join(display_parts)
|
||||
response_placeholder.markdown(full_response)
|
||||
last_messages = messages
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
for msg in messages:
|
||||
msg_type = msg.get("type", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
# Return the final messages for history
|
||||
return last_messages
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
message_container.error(f"Error connecting to API: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
message_container.error(f"Error: {e}")
|
||||
return []
|
||||
yield msg_type, content
|
||||
|
||||
|
||||
# Main UI
|
||||
@@ -112,30 +87,14 @@ for item in st.session_state.chat_history:
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Type your message..."):
|
||||
# Add user message to history
|
||||
st.session_state.chat_history.append({"role": "user", "content": prompt})
|
||||
|
||||
# Send message and get response
|
||||
with st.chat_message("assistant"):
|
||||
final_messages = send_message(prompt, st.container())
|
||||
|
||||
# Add response to history
|
||||
if final_messages:
|
||||
for msg in final_messages:
|
||||
msg_type = msg.get("type", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if msg_type == "tool":
|
||||
if isinstance(content, (dict, list)):
|
||||
content_str = json.dumps(content, indent=2)
|
||||
else:
|
||||
content_str = str(content)
|
||||
st.session_state.chat_history.append(
|
||||
{"role": "assistant", "content": content_str, "is_tool": True}
|
||||
)
|
||||
elif msg_type in ["ai", "assistant"]:
|
||||
st.session_state.chat_history.append(
|
||||
{"role": "assistant", "content": content, "is_tool": False}
|
||||
)
|
||||
for msg_type, content in stream_chat(prompt):
|
||||
if msg_type == "tool":
|
||||
st.session_state.chat_history.append({"role": "tool", "content": content})
|
||||
elif msg_type in ["ai", "assistant"]:
|
||||
st.session_state.chat_history.append(
|
||||
{"role": "assistant", "content": content}
|
||||
)
|
||||
|
||||
st.rerun()
|
||||
|
||||
Reference in New Issue
Block a user