mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
e3373026d6
* Use the pydocstyle (D) rules with the Google convention. Add a ruff rule to catch missing documentation. The Google convention is used so that the undocumented-param (D417) rule is enabled to catch missing params; xref https://docs.astral.sh/ruff/rules/undocumented-param. Extended to also include the D213 (instead of D212) and D410 rules.
* Fix D100 Missing docstring in public module
* Fix D101 Missing docstring in public class
* Fix D103 Missing docstring in public function. Also ignore rule D205 to allow the first sentence of a docstring to wrap onto multiple lines.
* Fix D417 Missing argument description in the docstring
* Update indent in pyproject.toml file
57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
"""Tests for chat API endpoint."""
|
|
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from geo_assistant.agent.graph import create_graph
|
|
from geo_assistant.api.app import app
|
|
|
|
|
|
@pytest_asyncio.fixture
async def initialized_app():
    """Yield ``app`` with a chatbot graph attached to ``app.state``.

    The chatbot is built here by hand, the way the application's lifespan
    would set it up, and detached again during fixture teardown.
    """
    # Build the agent graph first, then attach it as the lifespan would.
    chatbot = await create_graph()
    app.state.chatbot = chatbot

    yield app

    # Teardown: only delete the attribute if it is still present, so a test
    # that already removed it does not cause an error here.
    if not hasattr(app.state, "chatbot"):
        return
    del app.state.chatbot
|
|
|
|
|
|
@pytest.mark.xfail
async def test_call_api(initialized_app):
    """Test calling the API at the /chat HTTP POST endpoint.

    Posts a single human message through an in-process ``ASGITransport`` and
    checks that the endpoint answers 200 with a non-empty NDJSON body.

    Args:
        initialized_app: The application with its chatbot attached, supplied
            by the ``initialized_app`` fixture.
    """
    async with AsyncClient(
        transport=ASGITransport(app=initialized_app),
        base_url="http://test",
    ) as client:
        # Use a unique thread id for this request.
        thread_id = uuid4()
        response = await client.post(
            "/chat",
            json={
                "agent_state_input": {
                    "messages": [
                        {
                            # Adjacent literals concatenate to the original
                            # single-line prompt.
                            "content": (
                                "Find The Whitney Hotel Boston and buffer 0.1km around it, "
                                "then fetch the NAIP imagery for the area from 2021 "
                                "and summarize the contents of the image."
                            ),
                            "type": "human",
                        },
                    ],
                    "place": None,
                    "search_area": None,
                },
                "thread_id": str(thread_id),
            },
        )

        # Surface the response body in the failure message rather than
        # printing the response object on every run.
        assert response.status_code == 200, response.text
        assert response.headers["content-type"] == "application/x-ndjson; charset=utf-8"

        # A non-streaming ``client.post`` has already read the full body, so
        # ``.text`` holds everything the endpoint emitted. It is always a str,
        # so a truthiness check covers both "present" and "non-empty".
        content = response.text
        assert content, "expected a non-empty NDJSON response body"
|