Rename agent state to GeoAssistantState (#6)

This commit is contained in:
Daniel Wiesmann
2025-12-04 12:26:28 +00:00
committed by GitHub
parent 4a7a2c050a
commit 93b405cda4
4 changed files with 13 additions and 13 deletions
+2 -2
View File
@@ -2,7 +2,7 @@ import datetime
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent from langchain.agents import create_agent
from geo_assistant.agent.state import AgentState from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.agent.llms import llm from geo_assistant.agent.llms import llm
from geo_assistant.tools.overture import get_overture_locations from geo_assistant.tools.overture import get_overture_locations
@@ -23,7 +23,7 @@ async def create_graph():
system_prompt=SYSTEM_PROMPT.format( system_prompt=SYSTEM_PROMPT.format(
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
), ),
state_schema=AgentState, state_schema=GeoAssistantState,
checkpointer=checkpointer, checkpointer=checkpointer,
) )
return graph return graph
+1 -1
View File
@@ -3,5 +3,5 @@ from geojson_pydantic import FeatureCollection
from typing import Optional from typing import Optional
class AgentState(BaseAgentState): class GeoAssistantState(BaseAgentState):
place: Optional[FeatureCollection] place: Optional[FeatureCollection]
+7 -7
View File
@@ -9,15 +9,15 @@ import logging
from pydantic import UUID4 from pydantic import UUID4
from geo_assistant.agent.graph import create_graph from geo_assistant.agent.graph import create_graph
from geo_assistant.agent.state import AgentState from geo_assistant.agent.state import GeoAssistantState
from geo_assistant.api.schemas.chat import ChatRequestBody, ChatResponse from geo_assistant.api.schemas.chat import ChatRequestBody, ChatResponse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Whitelist state fields that can be set by the user. # Whitelist state fields that can be set by the user.
# Note that these attrs need to be pydantic Fields and # Note that these attrs need to be pydantic Fields and
# need a description in the AgentState model. # need a description in the GeoAssistantState model.
UI_SET_FIELDS_WHITELIST = ["feature_collection", "messages"] UI_SET_FIELDS_WHITELIST = ["point", "messages"]
@asynccontextmanager @asynccontextmanager
@@ -39,7 +39,7 @@ app.add_middleware(
async def stream_chat( async def stream_chat(
ui_state_update: AgentState, ui_state_update: GeoAssistantState,
thread_id: UUID4, thread_id: UUID4,
chatbot: Any, chatbot: Any,
request: Request, request: Request,
@@ -61,8 +61,8 @@ async def stream_chat(
ui_messages = [] ui_messages = []
for key in vars_to_update.keys(): for key in vars_to_update.keys():
if hasattr(AgentState, key): if hasattr(GeoAssistantState, key):
field_info = getattr(AgentState, key) field_info = getattr(GeoAssistantState, key)
description = field_info.description if field_info else f"Field {key}" description = field_info.description if field_info else f"Field {key}"
if description: if description:
ui_messages.append( ui_messages.append(
@@ -95,7 +95,7 @@ async def stream_chat(
payload = update[agent] payload = update[agent]
if "feature_collection" not in payload: # TODO if "feature_collection" not in payload: # TODO
payload["feature_collection"] = None payload["feature_collection"] = None
state_payload = AgentState(**payload) state_payload = GeoAssistantState(**payload)
resp = ChatResponse(thread_id=str(thread_id), state=state_payload) resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
+3 -3
View File
@@ -1,12 +1,12 @@
from pydantic import BaseModel from pydantic import BaseModel
from geo_assistant.agent.state import AgentState from geo_assistant.agent.state import GeoAssistantState
class ChatRequestBody(BaseModel): class ChatRequestBody(BaseModel):
agent_state_input: AgentState agent_state_input: GeoAssistantState
thread_id: str thread_id: str
class ChatResponse(BaseModel): class ChatResponse(BaseModel):
thread_id: str thread_id: str
state: AgentState state: GeoAssistantState