mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-15 23:41:01 +02:00
Add satellite image summarization tool (#7)
* Add dspy & jupyterlab as dependency * Add image summarizer agent tool * Add test for summarize tool * Remove try except --------- Co-authored-by: Daniel Wiesmann <yellowcap@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bcc331bd5e
commit
2d34ee0a16
@@ -208,3 +208,7 @@ __marimo__/
|
|||||||
|
|
||||||
# VSCode
|
# VSCode
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|
||||||
|
# Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
nbs/*
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ dependencies = [
|
|||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
"duckdb",
|
"duckdb",
|
||||||
"shapely",
|
"shapely",
|
||||||
|
"dspy>=3.0.4",
|
||||||
"watchdog>=6.0.0",
|
"watchdog>=6.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -26,6 +27,7 @@ dev = [
|
|||||||
"pytest",
|
"pytest",
|
||||||
"pytest-asyncio",
|
"pytest-asyncio",
|
||||||
"pre-commit",
|
"pre-commit",
|
||||||
|
"jupyterlab>=4.5.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
"""Tools for summarizing satellite images using LLM-based analysis."""
|
||||||
|
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
import dspy
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langgraph.types import Command
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.tools.base import InjectedToolCallId
|
||||||
|
|
||||||
|
|
||||||
|
class SatImgSummary(dspy.Signature):
|
||||||
|
"Describe things you see in the satellite image."
|
||||||
|
|
||||||
|
img: dspy.Image = dspy.InputField(desc="A satellite image")
|
||||||
|
answer: str = dspy.OutputField(desc="Description of the image")
|
||||||
|
|
||||||
|
|
||||||
|
class SatImgSummaryAgent(dspy.Module):
|
||||||
|
"""Agent for generating summaries of satellite images using an LLM."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "ministral-3:14b-cloud",
|
||||||
|
api_base: str = "http://localhost:11434",
|
||||||
|
temperature: float = 0.5,
|
||||||
|
max_tokens: int = 4_096,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the satellite image summary agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The Ollama model to use for summarization
|
||||||
|
api_base: Base URL for the Ollama API
|
||||||
|
temperature: Sampling temperature (0-1)
|
||||||
|
max_tokens: Maximum tokens to generate
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.ollama_model = dspy.LM(
|
||||||
|
model=f"ollama/{model}",
|
||||||
|
api_base=api_base,
|
||||||
|
api_key="",
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
dspy.configure(lm=self.ollama_model)
|
||||||
|
self.summarizer = dspy.Predict(SatImgSummary)
|
||||||
|
|
||||||
|
def forward(self, img_url: str) -> dspy.Prediction:
|
||||||
|
"""Generate a summary for the given image URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_url: URL of the image to summarize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dspy.Prediction containing the image summary
|
||||||
|
"""
|
||||||
|
return self.summarizer(img=dspy.Image(img_url))
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance to avoid repeated initialization
|
||||||
|
_SUMMARIZER_AGENT = SatImgSummaryAgent()
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def summarize_sat_img(
|
||||||
|
img_url: str,
|
||||||
|
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
|
||||||
|
) -> Command:
|
||||||
|
"""Summarize the contents of a satellite image using an LLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_url: URL of the satellite image to analyze
|
||||||
|
tool_call_id: Optional ID for tracking the tool call
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Command containing the image summary and metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the image URL is invalid or the image cannot be processed
|
||||||
|
"""
|
||||||
|
if not img_url or not isinstance(img_url, str):
|
||||||
|
raise ValueError("img_url must be a non-empty string")
|
||||||
|
|
||||||
|
summary = _SUMMARIZER_AGENT(img_url)
|
||||||
|
message_content = summary.answer
|
||||||
|
artifact = {"img_url": img_url}
|
||||||
|
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=message_content,
|
||||||
|
artifact=artifact,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
"""Tests for the satellite image summarization tool."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from langchain_core.tools.base import ToolCall
|
||||||
|
from geo_assistant.tools.summarize import summarize_sat_img
|
||||||
|
|
||||||
|
# Sample test data
|
||||||
|
TEST_IMAGE_URL = "https://petapixel.com/assets/uploads/2022/08/French-Officials-Use-Satellite-Photos-and-AI-to-Spot-Unregistered-Pools-1536x806.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"img_url,summary",
|
||||||
|
[
|
||||||
|
(TEST_IMAGE_URL, "building"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_summarize_sat_img(img_url, summary):
|
||||||
|
command = summarize_sat_img.invoke(
|
||||||
|
ToolCall(
|
||||||
|
name="summarize_sat_img",
|
||||||
|
type="tool_call",
|
||||||
|
args={"img_url": img_url},
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
print(command.update.get("messages"))
|
||||||
|
assert summary in command.update.get("messages")[-1].content
|
||||||
Reference in New Issue
Block a user