mirror of
https://github.com/dataforcanada/d4c-service-geo-assistant.git
synced 2026-06-13 14:31:01 +02:00
Add satellite image summarization tool (#7)
* Add dspy & jupyterlab as dependency * Add image summarizer agent tool * Add test for summarize tool * Remove try except --------- Co-authored-by: Daniel Wiesmann <yellowcap@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bcc331bd5e
commit
2d34ee0a16
@@ -0,0 +1,97 @@
|
||||
"""Tools for summarizing satellite images using LLM-based analysis."""
|
||||
|
||||
from typing import Annotated, Optional
|
||||
import dspy
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import InjectedToolCallId
|
||||
|
||||
|
||||
class SatImgSummary(dspy.Signature):
|
||||
"Describe things you see in the satellite image."
|
||||
|
||||
img: dspy.Image = dspy.InputField(desc="A satellite image")
|
||||
answer: str = dspy.OutputField(desc="Description of the image")
|
||||
|
||||
|
||||
class SatImgSummaryAgent(dspy.Module):
|
||||
"""Agent for generating summaries of satellite images using an LLM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "ministral-3:14b-cloud",
|
||||
api_base: str = "http://localhost:11434",
|
||||
temperature: float = 0.5,
|
||||
max_tokens: int = 4_096,
|
||||
) -> None:
|
||||
"""Initialize the satellite image summary agent.
|
||||
|
||||
Args:
|
||||
model: The Ollama model to use for summarization
|
||||
api_base: Base URL for the Ollama API
|
||||
temperature: Sampling temperature (0-1)
|
||||
max_tokens: Maximum tokens to generate
|
||||
"""
|
||||
super().__init__()
|
||||
self.ollama_model = dspy.LM(
|
||||
model=f"ollama/{model}",
|
||||
api_base=api_base,
|
||||
api_key="",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
dspy.configure(lm=self.ollama_model)
|
||||
self.summarizer = dspy.Predict(SatImgSummary)
|
||||
|
||||
def forward(self, img_url: str) -> dspy.Prediction:
|
||||
"""Generate a summary for the given image URL.
|
||||
|
||||
Args:
|
||||
img_url: URL of the image to summarize
|
||||
|
||||
Returns:
|
||||
dspy.Prediction containing the image summary
|
||||
"""
|
||||
return self.summarizer(img=dspy.Image(img_url))
|
||||
|
||||
|
||||
# Singleton instance to avoid repeated initialization
|
||||
_SUMMARIZER_AGENT = SatImgSummaryAgent()
|
||||
|
||||
|
||||
@tool
|
||||
def summarize_sat_img(
|
||||
img_url: str,
|
||||
tool_call_id: Annotated[Optional[str], InjectedToolCallId] = None,
|
||||
) -> Command:
|
||||
"""Summarize the contents of a satellite image using an LLM.
|
||||
|
||||
Args:
|
||||
img_url: URL of the satellite image to analyze
|
||||
tool_call_id: Optional ID for tracking the tool call
|
||||
|
||||
Returns:
|
||||
Command containing the image summary and metadata
|
||||
|
||||
Raises:
|
||||
ValueError: If the image URL is invalid or the image cannot be processed
|
||||
"""
|
||||
if not img_url or not isinstance(img_url, str):
|
||||
raise ValueError("img_url must be a non-empty string")
|
||||
|
||||
summary = _SUMMARIZER_AGENT(img_url)
|
||||
message_content = summary.answer
|
||||
artifact = {"img_url": img_url}
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=message_content,
|
||||
artifact=artifact,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user