mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Add satellite image summarization tool (#7)
* Add dspy & jupyterlab as dependency * Add image summarizer agent tool * Add test for summarize tool * Remove try except --------- Co-authored-by: Daniel Wiesmann <yellowcap@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bcc331bd5e
commit
2d34ee0a16
@@ -208,3 +208,7 @@ __marimo__/
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
|
||||
# Notebook
|
||||
.ipynb_checkpoints
|
||||
nbs/*
|
||||
|
||||
@@ -17,6 +17,7 @@ dependencies = [
|
||||
"python-dotenv",
|
||||
"duckdb",
|
||||
"shapely",
|
||||
"dspy>=3.0.4",
|
||||
"watchdog>=6.0.0",
|
||||
]
|
||||
|
||||
@@ -26,6 +27,7 @@ dev = [
|
||||
"pytest",
|
||||
"pytest-asyncio",
|
||||
"pre-commit",
|
||||
"jupyterlab>=4.5.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
"""Tools for summarizing satellite images using LLM-based analysis."""
|
||||
|
||||
from typing import Annotated, Optional
|
||||
import dspy
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
|
||||
|
||||
class SatImgSummary(dspy.Signature):
    """Describe things you see in the satellite image."""

    # NOTE: the docstring above and the `desc` strings below are runtime text
    # (dspy presumably forwards them to the model as instructions and field
    # hints -- confirm against the dspy Signature docs), so their wording is
    # deliberately left unchanged.

    # Input field: the satellite image to describe.
    img: dspy.Image = dspy.InputField(desc="A satellite image")

    # Output field: free-text description of the image.
    answer: str = dspy.OutputField(desc="Description of the image")
|
||||
|
||||
|
||||
class SatImgSummaryAgent(dspy.Module):
    """dspy module that asks an Ollama-served LLM to describe a satellite image."""

    def __init__(
        self,
        model: str = "ministral-3:14b-cloud",
        api_base: str = "http://localhost:11434",
        temperature: float = 0.5,
        max_tokens: int = 4_096,
    ) -> None:
        """Create the language-model client and the predictor.

        Args:
            model: Ollama model name; it is prefixed with ``ollama/`` before
                being handed to ``dspy.LM``.
            api_base: Base URL of the Ollama API.
            temperature: Sampling temperature passed to the LM.
            max_tokens: Maximum number of tokens to generate, passed to the LM.
        """
        super().__init__()

        lm_settings = {
            "model": f"ollama/{model}",
            "api_base": api_base,
            "api_key": "",
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        self.ollama_model = dspy.LM(**lm_settings)

        # NOTE: this registers the LM through dspy's process-wide configuration,
        # so constructing an agent affects other dspy code in the same process.
        dspy.configure(lm=self.ollama_model)

        self.summarizer = dspy.Predict(SatImgSummary)

    def forward(self, img_url: str) -> dspy.Prediction:
        """Run the predictor on the image found at ``img_url``.

        Args:
            img_url: URL of the image to summarize.

        Returns:
            The ``dspy.Prediction`` produced by the ``SatImgSummary`` predictor.
        """
        image = dspy.Image(img_url)
        return self.summarizer(img=image)
|
||||
|
||||
|
||||
# Module-level singleton: built once when this module is imported so that every tool call reuses the same agent instead of re-initializing it
|
||||
_SUMMARIZER_AGENT = SatImgSummaryAgent()
|
||||
|
||||
|
||||
@tool
def summarize_sat_img(
    img_url: str,
    tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
) -> Command:
    """Summarize the contents of a satellite image using an LLM.

    Args:
        img_url: URL of the satellite image to analyze
        tool_call_id: Optional ID for tracking the tool call

    Returns:
        Command containing the image summary and metadata

    Raises:
        ValueError: If the image URL is invalid or the image cannot be processed
    """
    # NOTE(review): the docstring text is kept verbatim because the @tool
    # decorator presumably uses it as the tool description -- confirm before
    # rewording it.

    # Guard clause: reject anything that is not a non-empty string.
    if not (img_url and isinstance(img_url, str)):
        raise ValueError("img_url must be a non-empty string")

    # Delegate to the shared module-level agent.
    prediction = _SUMMARIZER_AGENT(img_url)

    # The summary text becomes the message content; the source URL is attached
    # as the artifact.
    reply = ToolMessage(
        content=prediction.answer,
        artifact={"img_url": img_url},
        tool_call_id=tool_call_id,
    )
    return Command(update={"messages": [reply]})
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Tests for the satellite image summarization tool."""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
from langchain_core.tools.base import ToolCall
|
||||
from geo_assistant.tools.summarize import summarize_sat_img
|
||||
|
||||
# Sample test data
|
||||
TEST_IMAGE_URL = "https://petapixel.com/assets/uploads/2022/08/French-Officials-Use-Satellite-Photos-and-AI-to-Spot-Unregistered-Pools-1536x806.jpg"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "img_url,summary",
    [
        (TEST_IMAGE_URL, "building"),
    ],
)
def test_summarize_sat_img(img_url, summary):
    """Invoke the summarize tool and check the expected keyword is mentioned.

    Nothing is mocked here, so this presumably needs the image URL and the
    configured LLM backend to be reachable -- treat it as an integration test.

    Args:
        img_url: URL of the image handed to the tool.
        summary: Keyword expected to appear in the generated description.
    """
    command = summarize_sat_img.invoke(
        ToolCall(
            name="summarize_sat_img",
            type="tool_call",
            args={"img_url": img_url},
            id=str(uuid.uuid4()),
        )
    )

    messages = command.update.get("messages")
    # Fail with a clear message instead of an IndexError/TypeError on [-1].
    assert messages, "summarize_sat_img returned no messages"

    content = messages[-1].content
    # LLM output is free-form text: compare case-insensitively so that
    # capitalisation differences (e.g. "Buildings") do not fail the test.
    assert summary.lower() in content.lower(), (
        f"expected {summary!r} in the image summary, got: {content!r}"
    )
|
||||
Reference in New Issue
Block a user