mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Rename agent state to GeoAssistantState (#6)
This commit is contained in:
@@ -2,7 +2,7 @@ import datetime
|
|||||||
|
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from geo_assistant.agent.state import AgentState
|
from geo_assistant.agent.state import GeoAssistantState
|
||||||
from geo_assistant.agent.llms import llm
|
from geo_assistant.agent.llms import llm
|
||||||
from geo_assistant.tools.overture import get_overture_locations
|
from geo_assistant.tools.overture import get_overture_locations
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ async def create_graph():
|
|||||||
system_prompt=SYSTEM_PROMPT.format(
|
system_prompt=SYSTEM_PROMPT.format(
|
||||||
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
now=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
),
|
),
|
||||||
state_schema=AgentState,
|
state_schema=GeoAssistantState,
|
||||||
checkpointer=checkpointer,
|
checkpointer=checkpointer,
|
||||||
)
|
)
|
||||||
return graph
|
return graph
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ from geojson_pydantic import FeatureCollection
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class AgentState(BaseAgentState):
|
class GeoAssistantState(BaseAgentState):
|
||||||
place: Optional[FeatureCollection]
|
place: Optional[FeatureCollection]
|
||||||
|
|||||||
@@ -9,15 +9,15 @@ import logging
|
|||||||
from pydantic import UUID4
|
from pydantic import UUID4
|
||||||
|
|
||||||
from geo_assistant.agent.graph import create_graph
|
from geo_assistant.agent.graph import create_graph
|
||||||
from geo_assistant.agent.state import AgentState
|
from geo_assistant.agent.state import GeoAssistantState
|
||||||
from geo_assistant.api.schemas.chat import ChatRequestBody, ChatResponse
|
from geo_assistant.api.schemas.chat import ChatRequestBody, ChatResponse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Whitelist state fields that can be set by the user.
|
# Whitelist state fields that can be set by the user.
|
||||||
# Note that these attrs need to be pydantic Fields and
|
# Note that these attrs need to be pydantic Fields and
|
||||||
# need a description in the AgentState model.
|
# need a description in the GeoAssistantState model.
|
||||||
UI_SET_FIELDS_WHITELIST = ["feature_collection", "messages"]
|
UI_SET_FIELDS_WHITELIST = ["point", "messages"]
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -39,7 +39,7 @@ app.add_middleware(
|
|||||||
|
|
||||||
|
|
||||||
async def stream_chat(
|
async def stream_chat(
|
||||||
ui_state_update: AgentState,
|
ui_state_update: GeoAssistantState,
|
||||||
thread_id: UUID4,
|
thread_id: UUID4,
|
||||||
chatbot: Any,
|
chatbot: Any,
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -61,8 +61,8 @@ async def stream_chat(
|
|||||||
|
|
||||||
ui_messages = []
|
ui_messages = []
|
||||||
for key in vars_to_update.keys():
|
for key in vars_to_update.keys():
|
||||||
if hasattr(AgentState, key):
|
if hasattr(GeoAssistantState, key):
|
||||||
field_info = getattr(AgentState, key)
|
field_info = getattr(GeoAssistantState, key)
|
||||||
description = field_info.description if field_info else f"Field {key}"
|
description = field_info.description if field_info else f"Field {key}"
|
||||||
if description:
|
if description:
|
||||||
ui_messages.append(
|
ui_messages.append(
|
||||||
@@ -95,7 +95,7 @@ async def stream_chat(
|
|||||||
payload = update[agent]
|
payload = update[agent]
|
||||||
if "feature_collection" not in payload: # TODO
|
if "feature_collection" not in payload: # TODO
|
||||||
payload["feature_collection"] = None
|
payload["feature_collection"] = None
|
||||||
state_payload = AgentState(**payload)
|
state_payload = GeoAssistantState(**payload)
|
||||||
|
|
||||||
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
|
resp = ChatResponse(thread_id=str(thread_id), state=state_payload)
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from geo_assistant.agent.state import AgentState
|
from geo_assistant.agent.state import GeoAssistantState
|
||||||
|
|
||||||
|
|
||||||
class ChatRequestBody(BaseModel):
|
class ChatRequestBody(BaseModel):
|
||||||
agent_state_input: AgentState
|
agent_state_input: GeoAssistantState
|
||||||
thread_id: str
|
thread_id: str
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
thread_id: str
|
thread_id: str
|
||||||
state: AgentState
|
state: GeoAssistantState
|
||||||
|
|||||||
Reference in New Issue
Block a user