mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-14 15:01:01 +02:00
Rename agent state to GeoAssistantState (#6)
This commit is contained in:
@@ -9,15 +9,15 @@ import logging
|
||||
from pydantic import UUID4
|
||||
|
||||
from geo_assistant.agent.graph import create_graph
|
||||
from geo_assistant.agent.state import AgentState
|
||||
from geo_assistant.agent.state import GeoAssistantState
|
||||
from geo_assistant.api.schemas.chat import ChatRequestBody, ChatResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Whitelist state fields that can be set by the user.
|
||||
# Note that these attrs need to be pydantic Fields and
|
||||
# need a description in the AgentState model.
|
||||
UI_SET_FIELDS_WHITELIST = ["feature_collection", "messages"]
|
||||
# need a description in the GeoAssistantState model.
|
||||
UI_SET_FIELDS_WHITELIST = ["point", "messages"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -39,7 +39,7 @@ app.add_middleware(
|
||||
|
||||
|
||||
async def stream_chat(
|
||||
ui_state_update: AgentState,
|
||||
ui_state_update: GeoAssistantState,
|
||||
thread_id: UUID4,
|
||||
chatbot: Any,
|
||||
request: Request,
|
||||
@@ -61,8 +61,8 @@ async def stream_chat(
|
||||
|
||||
ui_messages = []
|
||||
for key in vars_to_update.keys():
|
||||
if hasattr(AgentState, key):
|
||||
field_info = getattr(AgentState, key)
|
||||
if hasattr(GeoAssistantState, key):
|
||||
field_info = getattr(GeoAssistantState, key)
|
||||
description = field_info.description if field_info else f"Field {key}"
|
||||
if description:
|
||||
ui_messages.append(
|
||||
@@ -95,7 +95,7 @@ async def stream_chat(
|
||||
payload = update[agent]
|
||||
if "feature_collection" not in payload: # TODO
|
||||
payload["feature_collection"] = None
|
||||
state_payload = AgentState(**payload)
|
||||
state_payload = GeoAssistantState(**payload)
|
||||
|
||||
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user