This README details how SmolAgents leverages async tools like SummarizationTool for efficient LiteLLM processing. It explains how long text is handled with chunking, parallel API calls, and synthesis while maintaining a simple agent interface. Async tools improve performance, cost efficiency, and reliability. 🚀

πŸ› οΈ Async Tools for SmolAgents (LiteLLM)

This directory contains tools for handling LLM operations, most notably for document summarization, using asynchronous (async) techniques. Async Tools allow SmolAgents to process large inputs efficiently by splitting work into smaller chunks and running parallel API calls, all while keeping integration with the agent framework simple.


πŸ“ Preventing Context Length Issues

The Problem

Large text documents can exceed the fixed context window of LLMs (e.g. GPT-4), which may result in:

  • Truncation of important information
  • Incomplete processing of input
  • Inefficient token usage

The Async-Based Solution: Rolling Window Summarization

The SummarizationTool handles these challenges by:

  1. Chunking the Text: Splitting long documents into manageable pieces (see the sketch after this list).
  2. Parallel Processing: Using asynchronous calls (via asyncio.gather) to summarize chunks concurrently.
  3. Synthesis: Combining individual summaries into a coherent final summary.
  4. Clear Feedback: Returning a structured response that includes both the summary and process "thoughts."
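
Step 1 above can be pictured as a rolling window over sentences: sentences are accumulated until the token budget is reached, then the next chunk starts with the tail of the previous one carried over as overlap. A simplified sketch of that logic, condensed from the tool's _create_chunks_with_overlap (the token estimator is a rough stand-in, and oversized single sentences are not handled here):

import nltk

def chunk_text(text, chunk_size, overlap_size, estimate_tokens=lambda s: max(1, len(s) // 4)):
    """Split text into sentence chunks of ~chunk_size tokens with ~overlap_size tokens of overlap."""
    nltk.download("punkt", quiet=True)  # sentence tokenizer model
    sentences = nltk.sent_tokenize(text)
    chunks, current, tokens = [], [], 0
    for sentence in sentences:
        s_tokens = estimate_tokens(sentence)
        if current and tokens + s_tokens > chunk_size:
            chunks.append(current)
            # Carry the tail of the finished chunk forward as the overlap window
            overlap, carried = [], 0
            for s in reversed(current):
                if carried + estimate_tokens(s) > overlap_size:
                    break
                overlap.insert(0, s)
                carried += estimate_tokens(s)
            current, tokens = overlap, carried
        current.append(sentence)
        tokens += s_tokens
    if current:
        chunks.append(current)
    return chunks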

Example system prompt for tool selection:

system_prompt = (
    "IMPORTANT: If you receive a request to summarize large text, "
    "**DO NOT** attempt to process it yourself.\n"
    "Instead, use the 'summarization_tool' to summarize the text.\n"
)

⚡ Asynchronous Processing in Action

The Challenge with Synchronous Tools

In the standard (synchronous) setup, handling large documents may require multiple sequential LLM calls, which can lead to:

  • Increased latency
  • Greater chance of hitting rate limits
  • Higher cost due to redundant API calls

The Async Approach

By embedding async logic within tools like SummarizationTool, SmolAgents can:

  • Parallelize Work: Process hundreds of text chunks concurrently.
  • Improve Efficiency: Reduce overall wait times by not waiting for each LLM call to complete before starting the next.
  • Enhance Reliability: Gracefully handle retries and caching, especially under rate limits.
  • Maintain a Simple Interface: Present a synchronous API to the agent (via asyncio.run), while internally managing complex async workflows.

Example async flow in SummarizationTool:

class SummarizationTool(Tool):
    def forward(self, text: str) -> Dict:
        # Called synchronously by the CodeAgent
        result = asyncio.run(self._run_summary(text))
        return result

    async def _run_summary(self, text: str) -> Dict:
        # Split text into chunks and process them in parallel
        chunks = self.chunk_text(text)
        tasks = [self.process_chunk(chunk) for chunk in chunks]
        results = await asyncio.gather(*tasks)
        final_summary = await self.synthesize_results(results)
        return {"summary": final_summary, "thoughts": "Summarization completed asynchronously."}

🔗 Integration with SmolAgents

SmolAgents are designed to automatically select and integrate tools based on the task. For example, the CodeAgent includes a decision function that selects the summarization_tool when a given text exceeds a token threshold:

def decide_tool(self, task_text: str) -> Optional[str]:
    TOKEN_THRESHOLD = 1500  # Adjust threshold as needed
    if estimate_token_count(task_text) > TOKEN_THRESHOLD:
        return "summarization_tool"
    return None
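
The estimate_token_count helper used here comes from the project's utils and isn't included in this gist. A rough stand-in that assumes about four characters per token (a real tokenizer such as tiktoken would be more accurate) could be:

def estimate_token_count(text: str) -> int:
    """Very rough token estimate: ~4 characters per token of English text."""
    return max(1, len(text) // 4)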

When a long text is detected, the agent delegates summarization to the async tool, ensuring that the LLM's context limits are respected without extra manual intervention.

💡 Benefits of the Async Approach

Using async tools within SmolAgents provides several tangible benefits:

  • Improved Performance: Parallel processing of text chunks reduces latency.
  • Cost Efficiency: Caching and concurrent calls help optimize token usage (see the cache sketch below).
  • Scalability: Better suited for processing large documents or handling multiple requests.
  • Enhanced Reliability: Built-in retries and error handling improve robustness.
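
The cost-efficiency point relies on LiteLLM's response cache being enabled before any calls are made; the gist does this through its own initialize_litellm_cache helper, whose body isn't shown. A minimal stand-in using LiteLLM's in-process cache (an assumption about how that helper is implemented) might be:

import litellm
from litellm.caching import Cache

def initialize_litellm_cache() -> None:
    # Illustrative only: enable LiteLLM's local in-memory cache so repeated
    # identical prompts (e.g. re-run chunks) are served without a new API call.
    litellm.cache = Cache(type="local")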

πŸ“ Usage Example

Below is an example of initializing and using the SummarizationTool within a CodeAgent:

from smolagents import CodeAgent
from web_browser.tools.llm_tools.summarization_tool import SummarizationTool
from smolagents.models import LiteLLMModel
from web_browser.utils.agent_utils import build_code_agent_system_prompt

# Initialize the summarization tool
tools = [SummarizationTool()]

# Create a CodeAgent with the async tool integrated
agent = CodeAgent(
    tools=tools,
    model=LiteLLMModel(model_id="openai/gpt-4"),
    max_steps=3,
    verbosity_level=4,
)

# Build and assign the system prompt
agent.system_prompt = build_code_agent_system_prompt(tools)

# Process a long document
long_document = " ".join(["This is a test sentence."] * 1000)
response = agent.run("Please summarize this text:\n" + long_document)

print("Summary Response:")
print(response)

In this setup:

  • The agent automatically selects the async summarization tool based on input length.
  • The tool processes large inputs by chunking and parallel asynchronous calls.
  • The final summary and process details (thoughts) are returned in a structured format.
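
A practical note when running the example above: forward() uses asyncio.run(), which raises RuntimeError if an event loop is already running (for example inside a Jupyter notebook). The tool module imports nest_asyncio, presumably for this case; one way to apply it before constructing the agent (an assumption, not shown in the gist) is:

import asyncio
import nest_asyncio

try:
    asyncio.get_running_loop()
    nest_asyncio.apply()  # patch the running loop so asyncio.run() can nest
except RuntimeError:
    pass  # no loop running; asyncio.run() works as-is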

βš™οΈ Configuration & Customization

The summarization tool can be configured to adjust its behavior:

    Chunk Size & Overlap: Control the granularity of text segmentation.
    LLM Parameters: Specify model settings and API options.
    Caching & Retries: Optimize performance and handle transient errors.

Refer to the configuration settings in SUMMARIZATION_CONFIG for detailed options.
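
SUMMARIZATION_CONFIG itself isn't included in this gist; an illustrative override matching the SummarizationConfig schema used by the tool (the values below are placeholders, not the project defaults) could look like:

example_config = {
    "chunk_size": 1000,       # tokens per chunk, must be > 0
    "overlap_size": 100,      # tokens shared between consecutive chunks, >= 0
    "llm_params": {
        "model": "openai/gpt-4o-mini",   # required key
        "temperature": 0.3,
    },
    "max_retries": 3,
    "output": {"include_tokens": True, "include_model": True},
}

tool = SummarizationTool(config=example_config)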

✅ Testing and Error Handling

The tests (see test_summarization_tool.py) verify:

  • Integration with the CodeAgent
  • Correct output structure (i.e. a dictionary with summary and thoughts)
  • Robust error handling for empty, None, or very short text inputs

These tests ensure that the async summarization behaves as expected in various scenarios without exceeding the LLM's context limits.
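
For a quick standalone check before running the agent-level tests, the tool can be exercised directly, mirroring test_summarization_tool_output_format (this makes real LLM calls, so API credentials must be configured):

from web_browser.tools.llm_tools.summarization_tool import SummarizationTool

tool = SummarizationTool()
result = tool.forward(text=" ".join(["This is a test sentence."] * 100))
assert {"summary", "thoughts"} <= result.keys()
print(result["thoughts"])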

By leveraging async tools within SmolAgents, you can process large documents more efficiently while maintaining a simple interface for agent interactions. The approach is designed to be flexible and configurable, making it suitable for a range of LLM workflows.
import asyncio
import json
import jsonpickle
import time
from typing import Any, Dict, List, Optional, Union
import nltk
from deepmerge import always_merger
from loguru import logger
from pydantic import BaseModel, Field, field_validator, ValidationError
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from smolagents import LiteLLMModel, Tool
from litellm import acompletion
from web_browser.tools.llm_tools.utils.estimate_token_count import estimate_token_count
from web_browser.tools.llm_tools.utils.initialize_litellm_cache import initialize_litellm_cache
from web_browser.tools.llm_tools.utils.summarizer_config import SUMMARIZATION_CONFIG
from web_browser.utils.agent_utils import build_code_agent_system_prompt, collect_authorized_imports
from web_browser.utils.json_utils import clean_json_string
from web_browser.utils.file_utils import get_project_root, load_env_file
import nest_asyncio
# =====================
# Configuration Classes
# =====================
class LLMResponse(BaseModel):
"""
Schema for validating the structured LLM response.
Ensures the summarization result contains the final summary text and its metadata.
Also includes summary metadata from the _run_summary method.
"""
summary: str = Field(..., description="The final summarized content")
token_count: int = Field(..., description="Number of tokens in the summary")
model: str = Field(..., description="LLM model used for summarization")
total_cost: float = Field(..., description="Total cost of summarization (all chunks + final synthesis)")
total_duration: float = Field(..., description="Total time taken for summarization (all chunks + final synthesis)")
cache_hit_rate: float = Field(..., ge=0, le=1, description="Cache hit rate for summarization (0 to 1)")
chunk_metrics: List[Dict[str, Any]] = Field(
default_factory=list,
description="Metrics for each chunk summarization (model, cost, duration, cache_hit)",
)
final_metrics: Dict[str, Any] = Field(
...,
description="Metrics for the final synthesis (model, cost, duration, cache_hit)",
)
@field_validator("chunk_metrics", mode="before")
@classmethod
def validate_chunk_metrics(cls, v: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Validates that each chunk metric contains the required fields."""
required_fields = {"model", "cost", "duration", "cache_hit"}
for metric in v:
if not required_fields.issubset(metric.keys()):
raise ValueError(f"Chunk metric missing required fields: {required_fields - metric.keys()}")
return v
@field_validator("final_metrics", mode="before")
@classmethod
def validate_final_metrics(cls, v: Dict[str, Any]) -> Dict[str, Any]:
"""Validates that the final metrics contain the required fields."""
required_fields = {"model", "cost", "duration", "cache_hit"}
if not required_fields.issubset(v.keys()):
raise ValueError(f"Final metric missing required fields: {required_fields - v.keys()}")
return v
class SummarizationConfig(BaseModel):
chunk_size: int = Field(..., gt=0, description="Size of text chunks in tokens")
overlap_size: int = Field(
default=0, ge=0, description="Overlap between chunks in tokens"
)
llm_params: dict = Field(..., description="LLM parameters including model name")
max_retries: int = Field(
default=3, ge=0, description="Max retry attempts per chunk"
)
output: dict = Field(
default_factory=lambda: {"include_tokens": True, "include_model": True},
description="Output formatting options",
)
@field_validator("llm_params", mode="before")
@classmethod
def validate_llm_model(cls, v: dict) -> dict:
if "model" not in v:
raise ValueError("llm_params must contain 'model' key")
return v
# =====================
# Summarization Tool
# =====================
class SummarizationTool(Tool):
"""
Hierarchical document summarization tool.
Config Structure:
{
"chunk_size": int > 0,
"overlap_size": int >= 0,
"llm_params": {
"model": str,
...other LLM params
},
"max_retries": int >= 0
}
"""
name = "summarization_tool"
description = "Summarizes large documents using chunking and multi-step synthesis"
inputs = {
"text": {
"type": "string",
"description": "Full text content to summarize (can be longer than the model context limit). "
"Do not truncate or modify before sending."
},
"config": {
"type": "object",
"description": "Configuration for summarization process",
"default": SUMMARIZATION_CONFIG["default"],
"nullable": True,
},
}
output_type = "object"
default_model = "openai/gpt-4o-mini"
@property
def system_prompt(self) -> str:
return (
"IMPORTANT: ALWAYS use the 'summarization_tool' for ANY text summarization task. "
"You MUST return a summary in a structured format. "
"Ensure that you utilize the tool properly and follow the required execution format."
)
@property
def authorized_imports(self) -> List[str]:
return ["nltk", "asyncio", "litellm"]
def __init__(self, config: Optional[Dict] = None, **kwargs):
super().__init__(**kwargs)
# Merge provided config with defaults.
merged_config = always_merger.merge(
SUMMARIZATION_CONFIG["default"], config or {}
)
self.config = SummarizationConfig(**merged_config)
self.model_id = kwargs.get("model", self.default_model)
# Initialize caches and required resources.
# initialize_litellm_cache()
nltk.download("punkt", quiet=True)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=2, max=10),
retry=retry_if_exception_type(Exception),
)
async def fetch_llm_response(self, prompt: str, system_prompt: str, response_format: Optional[str] = None) -> Dict:
"""
Fetches a raw LLM response asynchronously with retries.
"""
start_time = time.monotonic()
try:
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
]
api_params = {
"model": self.model_id,
"messages": messages,
"temperature": 0.3,
"max_tokens": 1000,
"response_format": response_format,
"caching": True,
}
response = await acompletion(**api_params)
duration = time.monotonic() - start_time
# Calculate cost immediately while we have the raw response
cost = float(response._hidden_params.get('response_cost', 0.0)) if hasattr(response, '_hidden_params') else 0.0
cache_hit = bool(response._hidden_params.get('cache_hit', False)) if hasattr(response, '_hidden_params') else False
logger.debug(f"LLM Response - Cost: {cost}, Duration: {duration}, Cache Hit: {cache_hit}")
return {
"content": response["choices"][0]["message"]["content"],
"duration": duration,
"cost": cost,
"cache_hit": cache_hit,
"model": self.model_id
}
except Exception as e:
logger.error(f"Error in fetch_llm_response: {str(e)}")
raise
def format_dict_to_executable_code(self, result: Union[Dict[str, Any], List[Dict[str, Any]]]) -> str:
"""Formats results into executable Python code for SmolAgents.
Args:
result: Dictionary containing results or list of error dictionaries
Returns:
str: Formatted code blob that SmolAgents can execute
"""
# Fix 2: Handle both success and error cases
if isinstance(result, list):
# Error case
error_msg = result[0].get('error', 'Unknown error')
return (
f"Thoughts: Error occurred: {error_msg}\n\n"
"Code:\n"
"```py\n"
f"def run():\n return {json.dumps({'error': error_msg}, indent=4)}\n"
"```\n"
"<end_code>"
)
# Success case
function_code = (
"def run():\n"
f" return {json.dumps(result, indent=4)}"
)
code_blob = (
f"Thoughts: {result['thoughts']}\n\n"
"Code:\n"
f"```py\n{function_code}\n```\n"
"<end_code>"
)
return code_blob
def forward(self, text: str, config: Optional[Dict] = None) -> Dict:
"""Forward pass for the summarization tool.
Args:
text: Text to summarize
config: Optional configuration overrides
Returns:
Dict containing summary and thoughts
Raises:
ValueError: If text is empty or None
"""
if text is None:
logger.error("None text provided to forward")
raise ValueError("Text cannot be None")
logger.info(f"SummarizationTool.forward called with text: {text[:50]}...")
try:
# Run the async summarization
result = asyncio.run(self._run_summary(text, config))
# Return only the essential fields needed by SmolAgents
return {
"summary": result["summary"],
"thoughts": result["thoughts"]
}
except Exception as e:
logger.error(f"Error in forward method: {str(e)}")
if isinstance(e, ValueError):
raise # Re-raise ValueError for empty text
return {
"summary": f"Error occurred: {str(e)}",
"thoughts": f"Summarization failed: {str(e)}"
}
async def _run_summary(self, text: str, config: Optional[Dict] = None) -> Dict:
logger.info(f"_run_summary called with text length: {len(text)}")
if not text:
logger.error("Empty text provided to _run_summary")
raise ValueError("Empty text provided")
logger.info(f"SummarizationTool called with text: {text[:50]}...")
# Merge runtime config with base config.
merged_config = always_merger.merge(
self.config.model_dump(), config or {}
)
# If the text already fits within the max length, summarize directly
if estimate_token_count(text) <= merged_config["chunk_size"]:
start_time = time.monotonic()
response_text = await self.fetch_llm_response(
text,
"Summarize the following text:"
)
duration = time.monotonic() - start_time
cost = response_text["cost"]
token_count = estimate_token_count(response_text["content"])
# Initialize empty chunk metrics since we're not chunking
chunk_metrics = []
try:
response = LLMResponse.model_validate({
"summary": clean_json_string(response_text["content"]),
"token_count": token_count,
"model": self.model_id,
"total_cost": float(response_text["cost"]),
"total_duration": float(duration),
"cache_hit_rate": response_text["cache_hit"],
"chunk_metrics": [],
"final_metrics": {
"model": self.model_id,
"cost": float(response_text["cost"]),
"duration": float(duration),
"cache_hit": response_text["cache_hit"]
}
})
except ValidationError as e:
logger.error("Validation failed with errors:")
logger.error(e.errors())
logger.error("Input data was:")
logger.error({
"cost": response_text["cost"],
"duration_type": type(duration),
"model": self.model_id,
"token_count": token_count
})
raise
# Fix 1: Use .model_dump() to convert Pydantic model to dict
response_dict = response.model_dump()
# Construct thoughts
thoughts = (
f"Successfully summarized a {len(text)} character document. "
f"The text was processed in {len(chunk_metrics)} chunks due to length. "
f"Final summary is {token_count} tokens. "
f"Process achieved {response_text['cache_hit']*100:.1f}% cache hit rate "
f"and cost ${float(response_text['cost']):.4f}."
)
# Return dictionary with all fields
return {
"summary": response_dict["summary"],
"token_count": response_dict["token_count"],
"model": response_dict["model"],
"total_cost": response_dict["total_cost"],
"total_duration": response_dict["total_duration"],
"cache_hit_rate": response_dict["cache_hit_rate"],
"chunk_metrics": response_dict["chunk_metrics"],
"final_metrics": response_dict["final_metrics"],
"thoughts": thoughts
}
# Phase 1: Chunk the text.
sentences = nltk.sent_tokenize(text)
chunks = self._create_chunks_with_overlap(
sentences,
merged_config["chunk_size"],
merged_config["overlap_size"],
)
# Phase 2: Summarize each chunk in parallel.
tasks = [
self._summarize_chunk(" ".join(chunk), merged_config)
for chunk in chunks
]
chunk_results = await asyncio.gather(*tasks)
# Phase 3: Synthesize final summary from chunk summaries.
synthesized_input = "\n".join(r["content"] for r in chunk_results)
final_result = await self._synthesize_final_summary(
synthesized_input, merged_config, chunk_results
)
# Aggregate metrics.
total_cost = (
sum(r["cost"] for r in chunk_results) + final_result["total_cost"]
)
total_duration = (
sum(r["duration"] for r in chunk_results) + final_result["total_duration"]
)
cache_hits = sum(
1 for r in (chunk_results) if r["cache_hit"]
)
# Construct informative thoughts about the summarization process
thoughts = (
f"Successfully summarized a {len(text)} character document. "
f"The text was processed in {len(chunk_results)} chunks due to length. "
f"Final summary is {final_result['token_count']} tokens. "
f"Process achieved {cache_hits/(len(chunk_results) + 1)*100:.1f}% cache hit rate "
f"and cost ${total_cost:.4f}."
)
# Create the return dictionary with only what we want to pass to SmolAgents
return_dict = {
"summary": final_result["summary"],
"thoughts": thoughts
}
# Keep all metrics and processing details in a separate dictionary for logging/debugging
metrics = {
"token_count": final_result["token_count"],
"model": final_result["model"],
"total_cost": total_cost,
"total_duration": total_duration,
"cache_hit_rate": cache_hits / (len(chunk_results) + 1),
"chunk_metrics": [
{
"model": r["model"],
"cost": r["cost"],
"duration": r["duration"],
"cache_hit": r["cache_hit"],
}
for r in chunk_results
],
"final_metrics": {
"model": final_result["model"],
"cost": final_result["total_cost"],
"duration": final_result["total_duration"],
"cache_hit_rate": final_result["cache_hit_rate"],
}
}
logger.debug(f"Summary metrics: {metrics}")
return return_dict
def _create_chunks_with_overlap(
self, sentences: List[str], chunk_size: int, overlap_size: int
) -> List[List[str]]:
"""Creates chunks using a rolling window with token-based overlap."""
chunks: List[List[str]] = []
current_chunk: List[str] = []
current_tokens = 0
for sentence in sentences:
sentence_tokens = estimate_token_count(sentence)
# Handle oversized sentences (greater than chunk_size)
if sentence_tokens > chunk_size:
if current_chunk:
chunks.append(current_chunk)
chunks.append([sentence])
current_chunk = []
current_tokens = 0
continue
# Main chunk building logic
if current_tokens + sentence_tokens <= chunk_size:
current_chunk.append(sentence)
current_tokens += sentence_tokens
else:
chunks.append(current_chunk)
# Start new chunk with token-based overlap
overlap_tokens = 0
new_chunk = []
# Collect overlapping sentences within token budget
for s in reversed(current_chunk):
s_tokens = estimate_token_count(s)
if overlap_tokens + s_tokens > overlap_size:
break
new_chunk.insert(0, s)
overlap_tokens += s_tokens
current_chunk = new_chunk + [sentence]
current_tokens = overlap_tokens + sentence_tokens
# Validate new chunk doesn't exceed size
while current_tokens > chunk_size:
if len(current_chunk) > 1:
removed = current_chunk.pop(0)
current_tokens -= estimate_token_count(removed)
else:
chunks.append(current_chunk)
current_chunk = []
current_tokens = 0
break
# Final chunk validation
if current_chunk:
final_tokens = sum(estimate_token_count(s) for s in current_chunk)
if final_tokens > chunk_size:
chunks.extend(self._split_oversized_chunk(current_chunk, chunk_size))
else:
chunks.append(current_chunk)
return chunks
def _split_oversized_chunk(self, chunk: List[str], chunk_size: int) -> List[List[str]]:
"""Recursively splits chunks that exceed size limit after overlap handling."""
if not chunk:
return []
total_tokens = sum(estimate_token_count(s) for s in chunk)
if total_tokens <= chunk_size:
return [chunk]
# Find split point that maximizes chunk utilization
split_point = len(chunk) // 2
left_chunk = chunk[:split_point]
right_chunk = chunk[split_point:]
return (
self._split_oversized_chunk(left_chunk, chunk_size) +
self._split_oversized_chunk(right_chunk, chunk_size)
)
async def _summarize_chunk(
self,
chunk: str,
config: Dict[str, Any],
system_prompt: str = "Summarize the following text chunk:"
) -> Dict:
"""Summarizes a single chunk and collects telemetry."""
start_time = time.monotonic()
try:
response_text = await self.fetch_llm_response(chunk, system_prompt)
duration = time.monotonic() - start_time
cost = response_text["cost"]
return {
"content": clean_json_string(response_text["content"]),
"cost": cost,
"duration": duration,
"cache_hit": response_text["cache_hit"],
"model": self.model_id,
}
except Exception as e:
logger.error(f"Chunk summarization failed: {e}")
raise
async def _synthesize_final_summary(
self,
partial_summaries: str,
config: Dict[str, Any],
chunk_results: List[Dict[str, Any]]
) -> Dict:
"""Creates a final summary from chunk summaries."""
start_time = time.monotonic()
# Final synthesis step
response = await self.fetch_llm_response(
partial_summaries.strip(),
"Synthesize comprehensive final summary:"
)
final_duration = time.monotonic() - start_time
result = {
"summary": response["content"],
"token_count": estimate_token_count(response["content"]),
"model": self.model_id,
"total_cost": sum(r["cost"] for r in chunk_results) + response["cost"],
"total_duration": sum(r["duration"] for r in chunk_results) + final_duration,
"cache_hit_rate": (sum(1 for r in chunk_results if r["cache_hit"]) + int(response["cache_hit"])) / (len(chunk_results) + 1),
"chunk_metrics": [
{
"model": r["model"],
"cost": r["cost"],
"duration": r["duration"],
"cache_hit": r["cache_hit"],
}
for r in chunk_results
],
"final_metrics": {
"model": self.model_id,
"cost": response["cost"],
"duration": final_duration,
"cache_hit": response["cache_hit"],
}
}
return result
def _estimate_cost(self, response: Any) -> float:
"""
Estimates cost based on LiteLLM response.
"""
try:
# Handle our wrapped response format
if isinstance(response, dict) and "raw_response" in response:
import jsonpickle
raw_response = jsonpickle.decode(response["raw_response"])
if hasattr(raw_response, '_hidden_params'):
return float(raw_response._hidden_params.get('response_cost', 0.0))
elif isinstance(raw_response, dict) and '_hidden_params' in raw_response:
return float(raw_response['_hidden_params'].get('response_cost', 0.0))
return 0.0
except Exception as e:
logger.warning(f"Failed to get cost from response: {e}")
return 0.0
# =====================
# Usage Example with CodeAgent
# =====================
@retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(3))
def run_agent_with_retry(agent, prompt):
return agent.run(prompt)
if __name__ == "__main__":
import os
import wikipedia
from web_browser.tools.llm_tools.utils.load_wikipedia_text import load_wikipedia_text
from smolagents import CodeAgent
from smolagents.models import LiteLLMModel
from web_browser.tools.llm_tools.llm_async_batch_call import (
BatchPromptProcessorTool,
) # Provided async tool
initialize_litellm_cache()
wikipedia.set_lang("en")
project_root = get_project_root()
load_env_file()
tools = [SummarizationTool(), BatchPromptProcessorTool()]
# Build the system prompt WITH placeholders intact
# system_prompt = (
# "You are an expert at processing tasks. Follow these rules:\n"
# "Authorized imports: {{authorized_imports}}\n"
# "{{tool_descriptions}}\n\n"
# )
system_prompt = build_code_agent_system_prompt(tools, additional_imports=["wikipedia"])
authorized_imports = collect_authorized_imports(tools)
# Create the CodeAgent - let it handle the placeholder replacements
agent = CodeAgent(
tools=tools,
model=LiteLLMModel(
model_id="openai/gpt-4o",
api_key=os.getenv("OPENAI_API_KEY"),
),
# system_prompt=system_prompt,
# tool_description_template=(
# "Use the following tools when appropriate: {{tools}}. "
# "Authorized imports: {{authorized_imports}} "
# ),
max_steps=7,
verbosity_level=4,
additional_authorized_imports=["asyncio", "litellm", "wikipedia"],
)
agent.system_prompt = build_code_agent_system_prompt(
tools, additional_imports=["wikipedia"]
)
agent.tool_description_template = (
"Use the following tools when appropriate: {{tools}}. "
"Authorized imports: {{authorized_imports}} "
)
# Example text to summarize.
title = "Artificial Intelligence"
(text_to_summarize, token_count) = load_wikipedia_text(title)
# text_to_summarize = (
# "Artificial Intelligence (AI) is a branch of computer science that aims to create "
# "intelligent machines capable of performing tasks that typically require human intelligence. "
# "These tasks include problem-solving, learning, planning, and language understanding. "
# "Recent advances in machine learning and deep learning have significantly boosted AI research, "
# "leading to breakthroughs in fields such as healthcare, finance, and transportation. "
# "Despite these advances, AI also raises ethical and social challenges that need to be addressed."
# )
user_question = f"Please summarize the text:\n{text_to_summarize}"
# Run the agent synchronously to invoke the summarization tool.
summarization_response = run_agent_with_retry(agent, user_question)
print("Summarization Response:")
print(summarization_response)
from typing import List, Optional, Set
from smolagents import Tool
from loguru import logger


def collect_authorized_imports(
    tools: List[Tool], additional_imports: Optional[List[str]] = None
) -> List[str]:
    """
    Collects all authorized imports from a list of tools and additional imports.

    Args:
        tools: List of tools to collect authorized imports from.
        additional_imports: Optional list of additional authorized imports.

    Returns:
        List[str]: A list of all authorized imports.

    Raises:
        ValueError: If no tools are provided or if a tool is missing required attributes.
    """
    # Validate inputs
    if not tools:
        logger.error("No tools provided to collect authorized imports.")
        raise ValueError("At least one tool must be provided.")

    authorized_imports: Set[str] = set()

    # Collect authorized imports from tools
    for tool in tools:
        try:
            if hasattr(tool, "authorized_imports"):
                authorized_imports.update(tool.authorized_imports)
                logger.debug(f"Added authorized imports for tool: {tool.name}")
            else:
                logger.warning(
                    f"Tool '{tool.name}' does not have an 'authorized_imports' attribute."
                )
        except Exception as e:
            logger.error(f"Error processing tool '{tool.name}': {e}")
            raise ValueError(f"Failed to process tool '{tool.name}': {e}")

    # Add additional authorized imports (if any)
    if additional_imports:
        authorized_imports.update(additional_imports)
        logger.debug(f"Added additional authorized imports: {additional_imports}")

    logger.info("Authorized imports collected successfully.")
    return list(authorized_imports)


def build_code_agent_system_prompt(
    tools: List[Tool], additional_imports: Optional[List[str]] = None
) -> str:
    """
    Builds a system prompt for an agent by aggregating tool-specific prompts.
    Ensures correct function execution format and structured outputs.
    Leaves placeholders intact for SmolAgents to handle.

    Args:
        tools: List of tools to include in the system prompt.
        additional_imports: Optional list of additional authorized imports.

    Returns:
        str: The complete system prompt for the agent with placeholders intact.

    Raises:
        ValueError: If no tools are provided or if a tool is missing required attributes.
    """
    if not tools:
        logger.error("No tools provided to build the system prompt.")
        raise ValueError("At least one tool must be provided.")

    tool_prompts = []
    for tool in tools:
        if hasattr(tool, "system_prompt"):
            tool_prompt = f"### {tool.name} ###\n{tool.system_prompt}"
            # Check if the tool returns a function or text
            if hasattr(tool, "output_type") and tool.output_type == "object":
                modified_prompt = (
                    f"{tool_prompt}\n\n"
                    "IMPORTANT: DO NOT pass arguments as a dictionary. "
                    "Always use named parameters.\n"
                    "Use this exact format:\n"
                    "```python\n"
                    "def run():\n"
                    f" return {tool.name}.forward(text=dynamic_text, config=dynamic_config)\n"
                    "result = run()\n"
                    "```\n"
                )
            else:
                modified_prompt = (
                    f"{tool_prompt}\n\n"
                    "IMPORTANT: This tool returns text directly. Do NOT wrap it in a function.\n"
                    "Simply return the text output."
                )
            tool_prompts.append(modified_prompt)
            logger.debug(f"Added system prompt for tool: {tool.name}")

    system_prompt = (
        "You are an expert at processing tasks. Follow these rules:\n\n" +
        "IMPORTANT: If you receive a request to summarize large text, " +
        "**DO NOT** attempt to process it yourself.\n" +
        "Instead, use the 'summarization_tool' to summarize the text.\n\n" +
        "\n\n".join(tool_prompts) + "\n\n" +
        "Authorized imports: {{authorized_imports}}\n"
        "{{tool_descriptions}}\n"
        "{{managed_agents_descriptions}}\n\n"
    )

    required_placeholders = [
        "{{authorized_imports}}",
        "{{tool_descriptions}}",
        "{{managed_agents_descriptions}}"
    ]
    for placeholder in required_placeholders:
        if placeholder not in system_prompt:
            logger.error(f"Missing required placeholder: {placeholder}")
            raise ValueError(f"System prompt is missing required placeholder: {placeholder}")

    logger.info("System prompt built successfully with placeholders intact.")
    logger.debug(f"Final system prompt:\n{system_prompt}")
    return system_prompt
import pytest
from smolagents import CodeAgent
from smolagents.models import LiteLLMModel
from web_browser.tools.llm_tools.summarization_tool import SummarizationTool
from web_browser.utils.agent_utils import build_code_agent_system_prompt


def test_summarization_tool_integration():
    # 1. Setup
    long_text = " ".join(["This is a test sentence."] * 1000)  # Create text > 1500 tokens
    short_text = "This is a short text that doesn't need summarization."
    tools = [SummarizationTool()]
    agent = CodeAgent(
        tools=tools,
        model=LiteLLMModel(model_id="openai/gpt-4"),
        max_steps=3,
        verbosity_level=4,
    )
    agent.system_prompt = build_code_agent_system_prompt(tools)

    # 2. Test direct summarization request
    response = agent.run(f"Please summarize this text:\n{long_text}")
    assert "summary" in response, "Summary not found in response"
    assert "thoughts" in response, "Thoughts not found in response"

    # 3. Test multi-step task with embedded long text
    multi_step_prompt = f"""
    Please follow these steps:
    1. Read this document: {long_text}
    2. Identify the main topics
    3. Create a brief outline
    """
    response = agent.run(multi_step_prompt)
    assert len(response) < len(long_text), "Text wasn't summarized in multi-step task"

    # 4. Test that short text isn't unnecessarily summarized
    response = agent.run(f"Please analyze this text:\n{short_text}")
    assert len(response) >= len(short_text), "Short text was unnecessarily summarized"

    # 5. Test tool selection
    chosen_tool = agent.decide_tool(long_text)
    assert chosen_tool == "summarization_tool", "Summarization tool not selected for long text"
    chosen_tool = agent.decide_tool(short_text)
    assert chosen_tool is None, "Tool incorrectly selected for short text"


def test_summarization_tool_output_format():
    tool = SummarizationTool()
    text = " ".join(["This is a test sentence."] * 100)
    result = tool.forward(text=text)

    # Verify the output structure matches what SmolAgents expects
    assert isinstance(result, dict), "Result should be a dictionary"
    assert "summary" in result, "Result should contain 'summary'"
    assert "thoughts" in result, "Result should contain 'thoughts'"
    assert isinstance(result["summary"], str), "Summary should be a string"
    assert isinstance(result["thoughts"], str), "Thoughts should be a string"


def test_summarization_tool_error_handling():
    tool = SummarizationTool()

    # Test empty text
    with pytest.raises(ValueError):
        tool.forward(text="")

    # Test None input
    with pytest.raises(ValueError):
        tool.forward(text=None)

    # Test very short text
    result = tool.forward(text="Short text")
    assert result["summary"], "Tool should handle short text gracefully"
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import re
import textwrap
import time
from collections import deque
from logging import getLogger
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
import yaml
from jinja2 import StrictUndefined, Template
from rich.console import Group
from rich.panel import Panel
from rich.rule import Rule
from rich.text import Text
from smolagents.agent_types import AgentAudio, AgentImage, handle_agent_output_types
from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from smolagents.monitoring import (
YELLOW_HEX,
AgentLogger,
LogLevel,
)
from smolagents.utils import (
AgentError,
AgentExecutionError,
AgentGenerationError,
AgentMaxStepsError,
AgentParsingError,
parse_code_blobs,
parse_json_tool_call,
truncate_content,
)
from .agent_types import AgentType
from .default_tools import TOOL_MAPPING, FinalAnswerTool
from .e2b_executor import E2BExecutor
from .local_python_executor import (
BASE_BUILTIN_MODULES,
LocalPythonInterpreter,
fix_final_answer_code,
)
from .models import (
ChatMessage,
MessageRole,
)
from .monitoring import Monitor
from .tools import Tool
logger = getLogger(__name__)
def get_variable_names(self, template: str) -> Set[str]:
pattern = re.compile(r"\{\{([^{}]+)\}\}")
return {match.group(1).strip() for match in pattern.finditer(template)}
def populate_template(template: str, variables: Dict[str, Any]) -> str:
compiled_template = Template(template, undefined=StrictUndefined)
try:
return compiled_template.render(**variables)
except Exception as e:
raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}")
class MultiStepAgent:
"""
Agent class that solves the given task step by step, using the ReAct framework:
While the objective is not reached, the agent will perform a cycle of action (given by the LLM) and observation (obtained from the environment).
Args:
tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
verbosity_level (`int`, default `1`): Level of verbosity of the agent's logs.
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
managed_agents (`list`, *optional*): Managed agents that the agent can call.
step_callbacks (`list[Callable]`, *optional*): Callbacks that will be called at each step.
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
name (`str`, *optional*): Necessary for a managed agent only - the name by which this agent can be called.
description (`str`, *optional*): Necessary for a managed agent only - the description of this agent.
provide_run_summary (`bool`, *optional*): Whether to provide a run summary when called as a managed agent.
final_answer_checks (`list`, *optional*): List of Callables to run before returning a final answer for checking validity.
"""
def __init__(
self,
tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage],
prompts_path: Optional[str] = None,
max_steps: int = 6,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
verbosity_level: int = 1,
grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None,
step_callbacks: Optional[List[Callable]] = None,
planning_interval: Optional[int] = None,
name: Optional[str] = None,
description: Optional[str] = None,
provide_run_summary: bool = False,
final_answer_checks: Optional[List[Callable]] = None,
):
if tool_parser is None:
tool_parser = parse_json_tool_call
self.agent_name = self.__class__.__name__
self.model = model
self.max_steps = max_steps
self.step_number: int = 0
self.tool_parser = tool_parser
self.grammar = grammar
self.planning_interval = planning_interval
self.state = {}
self.name = name
self.description = description
self.provide_run_summary = provide_run_summary
self.managed_agents = {}
if managed_agents is not None:
for managed_agent in managed_agents:
assert managed_agent.name and managed_agent.description, (
"All managed agents need both a name and a description!"
)
self.managed_agents = {agent.name: agent for agent in managed_agents}
for tool in tools:
assert isinstance(tool, Tool), f"This element is not of class Tool: {str(tool)}"
self.tools = {tool.name: tool for tool in tools}
if add_base_tools:
for tool_name, tool_class in TOOL_MAPPING.items():
if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent":
self.tools[tool_name] = tool_class()
self.tools["final_answer"] = FinalAnswerTool()
self.system_prompt = self.initialize_system_prompt()
self.input_messages = None
self.task = None
self.memory = AgentMemory(self.system_prompt)
self.logger = AgentLogger(level=verbosity_level)
self.monitor = Monitor(self.model, self.logger)
self.step_callbacks = step_callbacks if step_callbacks is not None else []
self.step_callbacks.append(self.monitor.update_metrics)
self.final_answer_checks = final_answer_checks
@property
def logs(self):
logger.warning(
"The 'logs' attribute is deprecated and will soon be removed. Please use 'self.memory.steps' instead."
)
return [self.memory.system_prompt] + self.memory.steps
def initialize_system_prompt(self):
"""To be implemented in child classes"""
pass
def write_memory_to_messages(
self,
summary_mode: Optional[bool] = False,
) -> List[Dict[str, str]]:
"""
Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
that can be used as input to the LLM. Adds a number of keywords (such as PLAN, error, etc) to help
the LLM.
"""
messages = self.memory.system_prompt.to_messages(summary_mode=summary_mode)
for memory_step in self.memory.steps:
messages.extend(memory_step.to_messages(summary_mode=summary_mode))
return messages
def visualize(self):
"""Creates a rich tree visualization of the agent's structure."""
self.logger.visualize_agent_tree(self)
def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str]:
"""
Parse action from the LLM output
Args:
model_output (`str`): Output of the LLM
split_token (`str`): Separator for the action. Should match the example in the system prompt.
"""
try:
split = model_output.split(split_token)
rationale, action = (
split[-2],
split[-1],
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
except Exception:
raise AgentParsingError(
f"No '{split_token}' token provided in your output.\nYour output:\n{model_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
self.logger,
)
return rationale.strip(), action.strip()
def provide_final_answer(self, task: str, images: Optional[list[str]]) -> str:
"""
Provide the final answer to the task, based on the logs of the agent's interactions.
Args:
task (`str`): Task to perform.
images (`list[str]`, *optional*): Paths to image(s).
Returns:
`str`: Final answer to the task.
"""
messages = [{"role": MessageRole.SYSTEM, "content": []}]
if images:
messages[0]["content"] = [
{
"type": "text",
"text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
}
]
messages[0]["content"].append({"type": "image"})
messages += self.write_memory_to_messages()[1:]
messages += [
{
"role": MessageRole.USER,
"content": [
{
"type": "text",
"text": f"Based on the above, please provide an answer to the following user request:\n{task}",
}
],
}
]
else:
messages[0]["content"] = [
{
"type": "text",
"text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
}
]
messages += self.write_memory_to_messages()[1:]
messages += [
{
"role": MessageRole.USER,
"content": [
{
"type": "text",
"text": f"Based on the above, please provide an answer to the following user request:\n{task}",
}
],
}
]
try:
chat_message: ChatMessage = self.model(messages)
return chat_message.content
except Exception as e:
return f"Error in generating final LLM output:\n{e}"
def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any:
"""
Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables.
Args:
tool_name (`str`): Name of the Tool to execute (should be one from self.tools).
arguments (Dict[str, str]): Arguments passed to the Tool.
"""
available_tools = {**self.tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
raise AgentExecutionError(error_msg, self.logger)
try:
if isinstance(arguments, str):
if tool_name in self.managed_agents:
observation = available_tools[tool_name].__call__(arguments)
else:
observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True)
elif isinstance(arguments, dict):
for key, value in arguments.items():
if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value]
if tool_name in self.managed_agents:
observation = available_tools[tool_name].__call__(**arguments)
else:
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
raise AgentExecutionError(error_msg, self.logger)
return observation
except Exception as e:
if tool_name in self.tools:
tool = self.tools[tool_name]
error_msg = (
f"Error whene executing tool {tool_name} with arguments {arguments}: {type(e).__name__}: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following: '{tool.description}'.\nIt takes inputs: {tool.inputs} and returns output type {tool.output_type}"
)
raise AgentExecutionError(error_msg, self.logger)
elif tool_name in self.managed_agents:
error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
)
raise AgentExecutionError(error_msg, self.logger)
def step(self, memory_step: ActionStep) -> Union[None, Any]:
"""To be implemented in children classes. Should return either None if the step is not final."""
pass
def run(
self,
task: str,
stream: bool = False,
reset: bool = True,
images: Optional[List[str]] = None,
additional_args: Optional[Dict] = None,
):
"""
Run the agent for the given task.
Args:
task (`str`): Task to perform.
stream (`bool`): Whether to run in a streaming way.
reset (`bool`): Whether to reset the conversation or keep it going from previous run.
images (`list[str]`, *optional*): Paths to image(s).
additional_args (`dict`): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
Example:
```py
from smolagents import CodeAgent
agent = CodeAgent(tools=[])
agent.run("What is the result of 2 power 3.7384?")
```
"""
self.task = task
if additional_args is not None:
self.state.update(additional_args)
self.task += f"""
You have been provided with these additional arguments, that you can access using the keys as variables in your python code:
{str(additional_args)}."""
self.system_prompt = self.initialize_system_prompt()
self.memory.system_prompt = SystemPromptStep(system_prompt=self.system_prompt)
if reset:
self.memory.reset()
self.monitor.reset()
self.logger.log_task(
content=self.task.strip(),
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
level=LogLevel.INFO,
title=self.name if hasattr(self, "name") else None,
)
self.memory.steps.append(TaskStep(task=self.task, task_images=images))
if stream:
# The steps are returned as they are executed through a generator to iterate on.
return self._run(task=self.task, images=images)
# Outputs are returned only at the end as a string. We only look at the last step
return deque(self._run(task=self.task, images=images), maxlen=1)[0]
def _run(self, task: str, images: List[str] | None = None) -> Generator[ActionStep | AgentType, None, None]:
"""
Run the agent in streaming mode and returns a generator of all the steps.
Args:
task (`str`): Task to perform.
images (`list[str]`): Paths to image(s).
"""
final_answer = None
self.step_number = 1
while final_answer is None and self.step_number <= self.max_steps:
step_start_time = time.time()
memory_step = ActionStep(
step_number=self.step_number,
start_time=step_start_time,
observations_images=images,
)
try:
if self.planning_interval is not None and self.step_number % self.planning_interval == 1:
self.planning_step(
task,
is_first_step=(self.step_number == 1),
step=self.step_number,
)
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
# Run one step!
final_answer = self.step(memory_step)
if final_answer is not None and self.final_answer_checks is not None:
for check_function in self.final_answer_checks:
try:
assert check_function(final_answer, self.memory)
except Exception as e:
final_answer = None
raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger)
except AgentError as e:
memory_step.error = e
finally:
memory_step.end_time = time.time()
memory_step.duration = memory_step.end_time - step_start_time
self.memory.steps.append(memory_step)
for callback in self.step_callbacks:
# For compatibility with old callbacks that don't take the agent as an argument
if len(inspect.signature(callback).parameters) == 1:
callback(memory_step)
else:
callback(memory_step, agent=self)
self.step_number += 1
yield memory_step
if final_answer is None and self.step_number == self.max_steps + 1:
error_message = "Reached max steps."
final_answer = self.provide_final_answer(task, images)
final_memory_step = ActionStep(
step_number=self.step_number, error=AgentMaxStepsError(error_message, self.logger)
)
final_memory_step.action_output = final_answer
final_memory_step.end_time = time.time()
final_memory_step.duration = memory_step.end_time - step_start_time
self.memory.steps.append(final_memory_step)
for callback in self.step_callbacks:
# For compatibility with old callbacks that don't take the agent as an argument
if len(inspect.signature(callback).parameters) == 1:
callback(final_memory_step)
else:
callback(final_memory_step, agent=self)
yield final_memory_step
yield handle_agent_output_types(final_answer)
def planning_step(self, task, is_first_step: bool, step: int) -> None:
"""
Used periodically by the agent to plan the next steps to reach the objective.
Args:
task (`str`): Task to perform.
is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
step (`int`): The number of the current step, used as an indication for the LLM.
"""
if is_first_step:
message_prompt_facts = {
"role": MessageRole.SYSTEM,
"content": [{"type": "text", "text": self.prompt_templates["planning"]["initial_facts"]}],
}
input_messages = [message_prompt_facts]
chat_message_facts: ChatMessage = self.model(input_messages)
answer_facts = chat_message_facts.content
message_prompt_plan = {
"role": MessageRole.USER,
"content": [
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["initial_plan"],
variables={
"task": task,
"tools": self.tools,
"managed_agents": self.managed_agents,
"answer_facts": answer_facts,
},
),
}
],
}
chat_message_plan: ChatMessage = self.model(
[message_prompt_plan],
stop_sequences=["<end_plan>"],
)
answer_plan = chat_message_plan.content
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
```
{answer_plan}
```"""
final_facts_redaction = f"""Here are the facts that I know so far:
```
{answer_facts}
```""".strip()
self.memory.steps.append(
PlanningStep(
model_input_messages=input_messages,
plan=final_plan_redaction,
facts=final_facts_redaction,
model_output_message_plan=chat_message_plan,
model_output_message_facts=chat_message_facts,
)
)
self.logger.log(
Rule("[bold]Initial plan", style="orange"),
Text(final_plan_redaction),
level=LogLevel.INFO,
)
else: # update plan
memory_messages = self.write_memory_to_messages(
summary_mode=False
) # This will not log the plan but will log facts
# Redact updated facts
facts_update_pre_messages = {
"role": MessageRole.SYSTEM,
"content": [{"type": "text", "text": self.prompt_templates["planning"]["update_facts_pre_messages"]}],
}
facts_update_post_messages = {
"role": MessageRole.SYSTEM,
"content": [{"type": "text", "text": self.prompt_templates["planning"]["update_facts_post_messages"]}],
}
input_messages = [facts_update_pre_messages] + memory_messages + [facts_update_post_messages]
chat_message_facts: ChatMessage = self.model(input_messages)
facts_update = chat_message_facts.content
# Redact updated plan
update_plan_pre_messages = {
"role": MessageRole.SYSTEM,
"content": [
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task}
),
}
],
}
update_plan_post_messages = {
"role": MessageRole.SYSTEM,
"content": [
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["update_plan_pre_messages"],
variables={
"task": task,
"tools": self.tools,
"managed_agents": self.managed_agents,
"facts_update": facts_update,
"remaining_steps": (self.max_steps - step),
},
),
}
],
}
chat_message_plan: ChatMessage = self.model(
[update_plan_pre_messages] + memory_messages + [update_plan_post_messages],
stop_sequences=["<end_plan>"],
)
# Log final facts and plan
final_plan_redaction = textwrap.dedent(
f"""I still need to solve the task I was given:
```
{task}
```
Here is my new/updated plan of action to solve the task:
```
{chat_message_plan.content}
```"""
)
final_facts_redaction = textwrap.dedent(
f"""Here is the updated list of the facts that I know:
```
{facts_update}
```"""
)
self.memory.steps.append(
PlanningStep(
model_input_messages=input_messages,
plan=final_plan_redaction,
facts=final_facts_redaction,
model_output_message_plan=chat_message_plan,
model_output_message_facts=chat_message_facts,
)
)
self.logger.log(
Rule("[bold]Updated plan", style="orange"),
Text(final_plan_redaction),
level=LogLevel.INFO,
)
def replay(self, detailed: bool = False):
"""Prints a pretty replay of the agent's steps.
Args:
detailed (bool, optional): If True, also displays the memory at each step. Defaults to False.
Careful: will increase log length exponentially. Use only for debugging.
"""
self.memory.replay(self.logger, detailed=detailed)
def __call__(self, task: str, **kwargs):
"""
This method is called only by a manager agent.
Adds additional prompting for the managed agent, runs it, and wraps the output.
"""
full_task = populate_template(
self.prompt_templates["managed_agent"]["task"],
variables=dict(name=self.name, task=task),
)
report = self.run(full_task, **kwargs)
answer = populate_template(
self.prompt_templates["managed_agent"]["report"], variables=dict(name=self.name, final_answer=report)
)
if self.provide_run_summary:
answer += "\n\nFor more detail, find below a summary of this agent's work:\n<summary_of_work>\n"
for message in self.write_memory_to_messages(summary_mode=True):
content = message["content"]
answer += "\n" + truncate_content(str(content)) + "\n---"
answer += "\n</summary_of_work>"
return answer
class ToolCallingAgent(MultiStepAgent):
"""
This agent uses JSON-like tool calls, using method `model.get_tool_call` to leverage the LLM engine's tool calling capabilities.
Args:
tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
**kwargs: Additional keyword arguments.
"""
def __init__(
self,
tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage],
prompts_path: Optional[str] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
yaml_path = os.path.join(os.path.dirname(__file__), "prompts", "toolcalling_agent.yaml")
with open(yaml_path, "r") as f:
self.prompt_templates = yaml.safe_load(f)
super().__init__(
tools=tools,
model=model,
prompts_path=prompts_path,
planning_interval=planning_interval,
**kwargs,
)
def initialize_system_prompt(self) -> str:
system_prompt = populate_template(
self.prompt_templates["system_prompt"],
variables={"tools": self.tools, "managed_agents": self.managed_agents},
)
return system_prompt
def step(self, memory_step: ActionStep) -> Union[None, Any]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Returns None if the step is not final.
"""
memory_messages = self.write_memory_to_messages()
self.input_messages = memory_messages
# Add new step in logs
memory_step.model_input_messages = memory_messages.copy()
try:
model_message: ChatMessage = self.model(
memory_messages,
tools_to_call_from=list(self.tools.values()),
stop_sequences=["Observation:"],
)
memory_step.model_output_message = model_message
if model_message.tool_calls is None or len(model_message.tool_calls) == 0:
raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.")
tool_call = model_message.tool_calls[0]
tool_name, tool_call_id = tool_call.function.name, tool_call.id
tool_arguments = tool_call.function.arguments
except Exception as e:
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) from e
memory_step.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
# Execute
self.logger.log(
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
level=LogLevel.INFO,
)
if tool_name == "final_answer":
if isinstance(tool_arguments, dict):
if "answer" in tool_arguments:
answer = tool_arguments["answer"]
else:
answer = tool_arguments
else:
answer = tool_arguments
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
final_answer = self.state[answer]
self.logger.log(
f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'.",
level=LogLevel.INFO,
)
else:
final_answer = answer
self.logger.log(
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"),
level=LogLevel.INFO,
)
memory_step.action_output = final_answer
return final_answer
else:
if tool_arguments is None:
tool_arguments = {}
observation = self.execute_tool_call(tool_name, tool_arguments)
observation_type = type(observation)
if observation_type in [AgentImage, AgentAudio]:
if observation_type == AgentImage:
observation_name = "image.png"
elif observation_type == AgentAudio:
observation_name = "audio.mp3"
# TODO: observation naming could allow for different names of same type
self.state[observation_name] = observation
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
self.logger.log(
f"Observations: {updated_information.replace('[', '|')}", # escape potential rich-tag-like components
level=LogLevel.INFO,
)
memory_step.observations = updated_information
return None
class CodeAgent(MultiStepAgent):
"""
In this agent, the tool calls will be formulated by the LLM in code format, then parsed and executed.
Args:
tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
use_e2b_executor (`bool`, default `False`): Whether to use the E2B executor for remote code execution.
max_print_outputs_length (`int`, *optional*): Maximum length of the print outputs.
**kwargs: Additional keyword arguments.
"""
def __init__(
self,
tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage],
prompts_path: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
use_e2b_executor: bool = False,
max_print_outputs_length: Optional[int] = None,
**kwargs,
):
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
yaml_path = os.path.join(os.path.dirname(__file__), "prompts", "code_agent.yaml")
with open(yaml_path, "r") as f:
self.prompt_templates = yaml.safe_load(f)
super().__init__(
tools=tools,
model=model,
grammar=grammar,
planning_interval=planning_interval,
**kwargs,
)
if "*" in self.additional_authorized_imports:
self.logger.log(
"Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
0,
)
if use_e2b_executor and len(self.managed_agents) > 0:
raise Exception(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
)
all_tools = {**self.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(
self.additional_authorized_imports,
list(all_tools.values()),
self.logger,
)
else:
self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports,
all_tools,
max_print_outputs_length=max_print_outputs_length,
)
def initialize_system_prompt(self) -> str:
system_prompt = populate_template(
self.prompt_templates["system_prompt"],
variables={
"tools": self.tools,
"managed_agents": self.managed_agents,
"authorized_imports": (
"You can import from any package you want."
if "*" in self.authorized_imports
else str(self.authorized_imports)
),
},
)
return system_prompt
def decide_tool(self, task_text: str) -> Optional[str]:
"""
Decide if one of the tools can completely answer the question.
This function first checks if the text is very long (using a token count) and then,
if needed, queries the LLM with a meta-prompt listing the available tools.
Returns:
The key of the tool to use (e.g. "summarization_tool"), or None if no tool should be used.
"""
# Example: if text is longer than a threshold, choose the summarization tool.
from web_browser.tools.llm_tools.utils.estimate_token_count import (
estimate_token_count,
)
TOKEN_THRESHOLD = 1500 # adjust threshold as needed
if estimate_token_count(task_text) > TOKEN_THRESHOLD:
# If the text is very long, it makes sense to use the summarization tool.
return "summarization_tool"
# Otherwise, prepare a meta prompt for a lightweight decision:
available_tool_names = list(self.tools.keys())
meta_prompt = (
"Given the following task:\n"
f"{task_text}\n\n"
"and these available tools: " + ", ".join(available_tool_names) + "\n\n"
"Which tool, if any, can completely answer this question? "
"Return only the tool name (or 'none' if no tool applies)."
)
# Use a lightweight model call (or even a pre-canned heuristic) for a quick answer.
meta_response = self.model([{"role": "user", "content": meta_prompt}])
chosen_tool = meta_response.content.strip().lower()
if chosen_tool in self.tools:
return chosen_tool
return None
def step(self, memory_step: ActionStep) -> Union[None, Any]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Returns None if the step is not final.
"""
# --- Meta Decision Step ---
chosen_tool = self.decide_tool(self.task)
if chosen_tool is not None:
tool = self.tools.get(chosen_tool)
if tool is None:
raise AgentExecutionError(
f"Tool '{chosen_tool}' was indicated but is not available.",
self.logger,
)
self.logger.log(
f"Meta-decision: Using tool '{chosen_tool}' to handle the task.",
level=LogLevel.INFO,
)
            # Call the tool directly, using the full task text as the tool's input.
final_answer = tool.forward(text=self.task, config=None)
memory_step.action_output = final_answer
return final_answer
# --- End Meta Decision Step ---
# Otherwise, continue with the normal chain-of-thought reasoning.
memory_messages = self.write_memory_to_messages()
self.input_messages = memory_messages.copy()
# Add new step in logs
memory_step.model_input_messages = memory_messages.copy()
try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
chat_message: ChatMessage = self.model(
self.input_messages,
stop_sequences=["<end_code>", "Observation:"],
**additional_args,
)
memory_step.model_output_message = chat_message
model_output = chat_message.content
memory_step.model_output = model_output
except Exception as e:
raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e
self.logger.log_markdown(
content=model_output,
title="Output message of the LLM:",
level=LogLevel.DEBUG,
)
# Parse
try:
code_action = fix_final_answer_code(parse_code_blobs(model_output))
except Exception as e:
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
raise AgentParsingError(error_msg, self.logger)
memory_step.tool_calls = [
ToolCall(
name="python_interpreter",
arguments=code_action,
id=f"call_{len(self.memory.steps)}",
)
]
# Execute
self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
is_final_answer = False
try:
output, execution_logs, is_final_answer = self.python_executor(
code_action,
self.state,
)
execution_outputs_console = []
if len(execution_logs) > 0:
execution_outputs_console += [
Text("Execution logs:", style="bold"),
Text(execution_logs),
]
observation = "Execution logs:\n" + execution_logs
except Exception as e:
if hasattr(self.python_executor, "state") and "_print_outputs" in self.python_executor.state:
execution_logs = str(self.python_executor.state["_print_outputs"])
if len(execution_logs) > 0:
execution_outputs_console = [
Text("Execution logs:", style="bold"),
Text(execution_logs),
]
memory_step.observations = "Execution logs:\n" + execution_logs
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
error_msg = str(e)
if "Import of " in error_msg and " is not allowed" in error_msg:
self.logger.log(
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
level=LogLevel.INFO,
)
raise AgentExecutionError(error_msg, self.logger)
truncated_output = truncate_content(str(output))
observation += "Last output from code snippet:\n" + truncated_output
memory_step.observations = observation
execution_outputs_console += [
Text(
f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}",
style=(f"bold {YELLOW_HEX}" if is_final_answer else ""),
),
]
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
memory_step.action_output = output
return output if is_final_answer else None
__all__ = ["MultiStepAgent", "CodeAgent", "ToolCallingAgent", "AgentMemory"]
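
For reference, below is a minimal, hypothetical usage sketch of the modified `CodeAgent` wired to the async `summarization_tool`. The import paths, model wrapper, and `SummarizationTool` constructor are assumptions and will likely differ in your project:

# Hypothetical wiring of the modified CodeAgent with the async summarization tool.
# Import paths and constructor signatures are assumptions; adapt them to your layout.
from smolagents import LiteLLMModel  # assumed LiteLLM-backed model wrapper

from web_browser.tools.llm_tools.summarization_tool import SummarizationTool  # assumed path

model = LiteLLMModel(model_id="gpt-4o-mini")  # assumed model id
agent = CodeAgent(
    tools=[SummarizationTool()],  # registered under its tool name, e.g. "summarization_tool"
    model=model,
)

# A long document: decide_tool() should route this straight to the summarization_tool
# instead of pushing the raw text through normal chain-of-thought code generation.
with open("report.txt") as f:
    long_text = f.read()

result = agent.run(f"Summarize the following text:\n{long_text}")
print(result)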