From 93b405cda44c94ceb94f5079fb065de841db7c55 Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Thu, 4 Dec 2025 12:26:28 +0000 Subject: [PATCH] Rename agent state to GeoAssistantState (#6) --- src/geo_assistant/agent/graph.py | 4 ++-- src/geo_assistant/agent/state.py | 2 +- src/geo_assistant/api/app.py | 14 +++++++------- src/geo_assistant/api/schemas/chat.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/geo_assistant/agent/graph.py b/src/geo_assistant/agent/graph.py index 6b1f21d..9625f7a 100644 --- a/src/geo_assistant/agent/graph.py +++ b/src/geo_assistant/agent/graph.py @@ -2,7 +2,7 @@ import datetime from langgraph.checkpoint.memory import InMemorySaver from langchain.agents import create_agent -from geo_assistant.agent.state import AgentState +from geo_assistant.agent.state import GeoAssistantState from geo_assistant.agent.llms import llm from geo_assistant.tools.overture import get_overture_locations @@ -23,7 +23,7 @@ async def create_graph(): system_prompt=SYSTEM_PROMPT.format( now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") ), - state_schema=AgentState, + state_schema=GeoAssistantState, checkpointer=checkpointer, ) return graph diff --git a/src/geo_assistant/agent/state.py b/src/geo_assistant/agent/state.py index d22ad72..8431f26 100644 --- a/src/geo_assistant/agent/state.py +++ b/src/geo_assistant/agent/state.py @@ -3,5 +3,5 @@ from geojson_pydantic import FeatureCollection from typing import Optional -class AgentState(BaseAgentState): +class GeoAssistantState(BaseAgentState): place: Optional[FeatureCollection] diff --git a/src/geo_assistant/api/app.py b/src/geo_assistant/api/app.py index 618a979..4b92892 100644 --- a/src/geo_assistant/api/app.py +++ b/src/geo_assistant/api/app.py @@ -9,15 +9,15 @@ import logging from pydantic import UUID4 from geo_assistant.agent.graph import create_graph -from geo_assistant.agent.state import AgentState +from geo_assistant.agent.state import GeoAssistantState from geo_assistant.api.schemas.chat import ChatRequestBody, ChatResponse logger = logging.getLogger(__name__) # Whitelist state fields that can be set by the user. # Note that these attrs need to be pydantic Fields and -# need a description in the AgentState model. -UI_SET_FIELDS_WHITELIST = ["feature_collection", "messages"] +# need a description in the GeoAssistantState model. +UI_SET_FIELDS_WHITELIST = ["point", "messages"] @asynccontextmanager @@ -39,7 +39,7 @@ app.add_middleware( async def stream_chat( - ui_state_update: AgentState, + ui_state_update: GeoAssistantState, thread_id: UUID4, chatbot: Any, request: Request, @@ -61,8 +61,8 @@ async def stream_chat( ui_messages = [] for key in vars_to_update.keys(): - if hasattr(AgentState, key): - field_info = getattr(AgentState, key) + if hasattr(GeoAssistantState, key): + field_info = getattr(GeoAssistantState, key) description = field_info.description if field_info else f"Field {key}" if description: ui_messages.append( @@ -95,7 +95,7 @@ async def stream_chat( payload = update[agent] if "feature_collection" not in payload: # TODO payload["feature_collection"] = None - state_payload = AgentState(**payload) + state_payload = GeoAssistantState(**payload) resp = ChatResponse(thread_id=str(thread_id), state=state_payload) diff --git a/src/geo_assistant/api/schemas/chat.py b/src/geo_assistant/api/schemas/chat.py index 26f35c6..ade45c0 100644 --- a/src/geo_assistant/api/schemas/chat.py +++ b/src/geo_assistant/api/schemas/chat.py @@ -1,12 +1,12 @@ from pydantic import BaseModel -from geo_assistant.agent.state import AgentState +from geo_assistant.agent.state import GeoAssistantState class ChatRequestBody(BaseModel): - agent_state_input: AgentState + agent_state_input: GeoAssistantState thread_id: str class ChatResponse(BaseModel): thread_id: str - state: AgentState + state: GeoAssistantState