@oneryalcin
Created January 23, 2026 23:31
RLM - Recursive Language Models - Smolagents Implementation
#!/usr/bin/env python3
"""
RLM v2 (Recursive Language Model) Module for Smolagents
Implements DSPy RLM-style strategies for handling large contexts:
- Peeking: Look at data structure before processing
- Grepping: Use string/regex matching before LLM calls
- Partition + Map: Chunk data and process with batched sub-LLM calls
- Summarization: Hierarchical compression for understanding the whole
- Hybrid: Combine strategies as needed
Features:
- Shared call counter between llm_query and llm_query_batched
- Step callbacks for tracking/debugging
- Variable metadata display
- RLM strategy guidance in prompts
Usage:
from rlm_v2 import RLMAgent, create_rlm_agent
agent = create_rlm_agent("gpt-4.1-nano")
result = agent.run(
task="Classify each review as positive/negative",
context=large_text
)
"""
from smolagents import CodeAgent, Tool
from smolagents.agents import ActionStep
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from typing import Any, Callable
# =============================================================================
# Shared Call Counter
# =============================================================================
class LLMCallCounter:
    """Thread-safe counter shared between llm_query tools."""

    def __init__(self, max_calls: int = 50):
        self.max_calls = max_calls
        self._count = 0
        self._lock = threading.Lock()

    def check_and_increment(self, n: int = 1) -> int:
        """Check if we can make n calls, increment counter, return new count."""
        with self._lock:
            if self._count + n > self.max_calls:
                raise RuntimeError(
                    f"LLM call limit exceeded: {self._count} + {n} > {self.max_calls}. "
                    "Use Python for aggregation instead of more LLM calls."
                )
            self._count += n
            return self._count

    def reset(self):
        with self._lock:
            self._count = 0

    @property
    def remaining(self) -> int:
        with self._lock:
            return self.max_calls - self._count

    @property
    def count(self) -> int:
        with self._lock:
            return self._count
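

# --- Illustrative usage sketch ------------------------------------------------
# Shows how the counter enforces the shared budget: once max_calls would be
# exceeded, check_and_increment raises instead of allowing more sub-LLM calls.
# The function name `_example_call_limit` is made up for this sketch and is not
# part of the module's API.
def _example_call_limit() -> None:
    counter = LLMCallCounter(max_calls=2)
    counter.check_and_increment(2)       # OK: budget exactly used up
    try:
        counter.check_and_increment(1)   # would exceed max_calls=2, so it raises
    except RuntimeError as exc:
        print(f"Blocked as expected: {exc}")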
# =============================================================================
# RLM Tools with Shared Counter
# =============================================================================
class LLMQueryTool(Tool):
    """Query a sub-LLM for semantic analysis."""

    name = "llm_query"
    description = (
        "Query a language model for semantic analysis (classification, summarization, understanding). "
        "Use Python string matching for pattern/location tasks - it's free and faster."
    )
    inputs = {
        "prompt": {"type": "string", "description": "The prompt to send to the LLM"}
    }
    output_type = "string"

    def __init__(self, model=None, counter: LLMCallCounter = None, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.counter = counter or LLMCallCounter()

    def forward(self, prompt: str) -> str:
        self.counter.check_and_increment(1)
        if not self.model:
            return "[ERROR] No model configured"
        try:
            from smolagents.models import ChatMessage, MessageRole
            messages = [ChatMessage(role=MessageRole.USER, content=prompt)]
            response = self.model.generate(messages)
            return response.content
        except Exception as e:
            return f"[ERROR] {e}"


class LLMQueryBatchedTool(Tool):
    """Query sub-LLM with multiple prompts in parallel."""

    name = "llm_query_batched"
    description = (
        "Query LLM with multiple prompts in PARALLEL (8x faster than sequential). "
        "Use for batch classification/analysis. Returns list of responses."
    )
    inputs = {
        "prompts": {"type": "array", "description": "List of prompts"}
    }
    output_type = "array"

    def __init__(
        self,
        model=None,
        counter: LLMCallCounter = None,
        max_workers: int = 8,
        thread_safe: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model = model
        self.counter = counter or LLMCallCounter()
        # If model is not thread-safe, force serial execution.
        self.max_workers = max_workers if thread_safe else 1

    def forward(self, prompts: list) -> list:
        if not prompts:
            return []
        n = len(prompts)
        self.counter.check_and_increment(n)
        if not self.model:
            return ["[ERROR] No model configured"] * n
        from smolagents.models import ChatMessage, MessageRole

        def query_one(prompt: str) -> str:
            try:
                messages = [ChatMessage(role=MessageRole.USER, content=prompt)]
                return self.model.generate(messages).content
            except Exception as e:
                return f"[ERROR] {e}"

        results = {}
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {executor.submit(query_one, p): i for i, p in enumerate(prompts)}
            for future in as_completed(futures):
                results[futures[future]] = future.result()
        return [results[i] for i in range(n)]
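

# --- Illustrative usage sketch ------------------------------------------------
# Demonstrates that llm_query and llm_query_batched draw from the same
# LLMCallCounter, so a batched call with n prompts consumes n calls of the
# shared budget. No model is wired in (model=None), so the tools return
# "[ERROR] ..." strings; only the counting behaviour is being shown. The
# function name `_example_shared_counter` is made up for this sketch.
def _example_shared_counter() -> None:
    counter = LLMCallCounter(max_calls=5)
    single = LLMQueryTool(model=None, counter=counter)
    batched = LLMQueryBatchedTool(model=None, counter=counter)
    single.forward("classify this review: great product")  # consumes 1 call
    batched.forward(["prompt a", "prompt b", "prompt c"])   # consumes 3 calls
    print(counter.count, counter.remaining)                 # -> 4 1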
# =============================================================================
# Variable Metadata
# =============================================================================
def make_variable_info(name: str, value: Any, preview_chars: int = 1000) -> str:
    """
    Create RLM-style variable metadata showing type, size, and preview.

    This helps the LLM understand the data without seeing all of it,
    encouraging strategic exploration (peeking) before processing.
    """
    if isinstance(value, str):
        total_len = len(value)
        lines = value.split('\n')
        preview = value[:preview_chars]
        truncated = len(value) > preview_chars
        return f"""Variable: `{name}` (access it in your code)
Type: str
Total length: {total_len:,} characters
Total lines: {len(lines):,}
Preview (first {preview_chars} chars):
```
{preview}{'...' if truncated else ''}
```"""
    elif isinstance(value, (list, tuple)):
        type_name = type(value).__name__
        total = len(value)
        preview = str(value[:10]) if total > 10 else str(value)
        return f"""Variable: `{name}` (access it in your code)
Type: {type_name}
Total items: {total:,}
Preview (first 10): {preview}"""
    elif isinstance(value, dict):
        keys = list(value.keys())
        preview_keys = keys[:10]
        return f"""Variable: `{name}` (access it in your code)
Type: dict
Total keys: {len(keys):,}
Keys preview: {preview_keys}"""
    else:
        str_val = str(value)[:preview_chars]
        return f"""Variable: `{name}` (access it in your code)
Type: {type(value).__name__}
Value: {str_val}"""


# Alias for compatibility
make_var_info = make_variable_info
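

# --- Illustrative usage sketch ------------------------------------------------
# Shows the kind of metadata string the agent sees instead of the raw data:
# type, size, and a short preview, which nudges it toward "peek first". The
# sample text and the function name `_example_variable_info` are made up for
# this sketch.
def _example_variable_info() -> None:
    sample = "Date: 2024-01-01 | Review: great product\n" * 500
    print(make_variable_info("context", sample, preview_chars=80))
    # Prints something like:
    #   Variable: `context` (access it in your code)
    #   Type: str
    #   Total length: 20,500 characters
    #   Total lines: 501
    #   Preview (first 80 chars):
    #   ...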
# =============================================================================
# OOLONG Prompt Parsing
# =============================================================================
def _strip_chat_markers(text: str) -> str:
    """Remove common chat markers used by some datasets."""
    if text.startswith("<|im_start|>user\n"):
        text = text[len("<|im_start|>user\n") :]
    for suffix in ("<|im_start|>assistant\n", "<|im_end|>\n", "<|im_end|>"):
        if text.endswith(suffix):
            text = text[: -len(suffix)]
    return text


def parse_oolong_prompt(prompt: str) -> tuple[str, str, str]:
    """
    Parse an OOLONG prompt into (header, context, question).

    The prompt format is typically:
    - An instruction header
    - Many lines of data starting with 'Date:'
    - A trailing question that starts with 'In the above data'
    """
    text = _strip_chat_markers(prompt)
    marker = "\n\nIn the above data"
    idx = text.rfind(marker)
    if idx == -1:
        raise ValueError("Could not find question marker in prompt.")
    body = text[:idx].rstrip()
    question = text[idx + 2 :].strip()
    lines = body.splitlines()
    data_start = None
    for i, line in enumerate(lines):
        if line.startswith("Date:"):
            data_start = i
            break
    if data_start is None:
        header = ""
        context = body.strip()
    else:
        header = "\n".join(lines[:data_start]).strip()
        context = "\n".join(lines[data_start:]).strip()
    return header, context, question
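

# --- Illustrative usage sketch ------------------------------------------------
# A tiny, made-up prompt in the OOLONG-style layout described above: an
# instruction header, 'Date:'-prefixed data lines, and a trailing question
# starting with 'In the above data'. The function name `_example_parse_oolong`
# and the sample content are invented for this sketch.
def _example_parse_oolong() -> None:
    prompt = (
        "You will answer a question about the data below.\n"
        "Date: 2024-01-01 | user posts about python\n"
        "Date: 2024-01-02 | user posts about rust\n"
        "\nIn the above data, how many posts mention python?"
    )
    header, context, question = parse_oolong_prompt(prompt)
    print(header)    # -> "You will answer a question about the data below."
    print(context)   # -> the two 'Date:' lines
    print(question)  # -> "In the above data, how many posts mention python?"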
# =============================================================================
# Step Callback for Debugging
# =============================================================================
def make_rlm_step_callback(counter: LLMCallCounter, verbose: bool = True) -> Callable:
    """Create a step callback that logs RLM metrics."""
    def callback(memory_step: ActionStep, agent: "RLMAgent") -> None:
        if not verbose:
            return
        step_num = memory_step.step_number
        calls_used = counter.count
        calls_remaining = counter.remaining
        # Log to observations
        rlm_info = (
            f"\n[RLM] Step {step_num} | LLM calls: {calls_used}/{counter.max_calls} used, "
            f"{calls_remaining} remaining"
        )
        if memory_step.observations:
            memory_step.observations += rlm_info
        else:
            memory_step.observations = rlm_info
    return callback
# =============================================================================
# RLM Prompt Addition
# =============================================================================
RLM_INSTRUCTIONS = """
## RLM Strategies for Large Context
You have access to `llm_query(prompt)` and `llm_query_batched(prompts)` for semantic analysis.
You have {max_llm_calls} sub-LLM calls available. Use them wisely!
### Strategy 0: READ THE METADATA FIRST
You will receive `variables_info` in the task. Read it before doing anything else.
It tells you type/size/preview so you can decide how to explore the data.
### Strategy 1: PEEK FIRST (Always do this!)
```python
print(f"Total length: {{len(context):,}} chars")
print(f"Lines: {{len(context.split(chr(10))):,}}")
print(context[:2000]) # See the format before processing
```
### Strategy 2: GREP WHEN POSSIBLE
String/regex matching is FREE and FAST. Use before LLM calls:
```python
import re
matches = [l for l in context.split('\\n') if 'keyword' in l]
# Or with regex: matches = re.findall(r'pattern', context)
```
### Strategy 3: PARTITION + MAP (for semantic tasks)
When you NEED semantic understanding, chunk and batch:
```python
lines = context.split('\\n')
chunks = [lines[i:i+50] for i in range(0, len(lines), 50)]
prompts = [f"Classify these:\\n{{chr(10).join(c)}}" for c in chunks]
results = llm_query_batched(prompts) # Parallel! 8x faster than loop
```
### Strategy 4: VERIFY BEFORE final_answer
```python
print(f"Found {{count}} matches")
print(matches[:5]) # Sanity check samples
# Then: final_answer(count)
```
### Strategy 5: SUMMARIZE HIERARCHICALLY
When the data is too big, summarize chunks, then summarize summaries.
```python
chunks = [lines[i:i+200] for i in range(0, len(lines), 200)]
prompts = [f"Summarize:\\n{{chr(10).join(c)}}" for c in chunks]
summaries = llm_query_batched(prompts)
final = llm_query("Combine these summaries into 5 bullet points:\\n" + "\\n".join(summaries))
```
### Strategy 6: HYBRID
Mix grep + summarize: use regex to reduce, then summarize the reduced set.
### DON'T WASTE LLM CALLS ON:
- Counting (use Python: `len()`, `sum()`)
- Pattern matching (use Python: `in`, `re.search()`)
- Filtering (use Python: list comprehensions)
"""
# =============================================================================
# RLMAgent Class
# =============================================================================
class RLMAgent(CodeAgent):
    """
    Recursive Language Model Agent.

    Extends CodeAgent with:
    - llm_query() / llm_query_batched() for sub-LLM calls
    - Shared call counter
    - RLM strategy guidance
    - Step callbacks for debugging
    """

    def __init__(
        self,
        model,
        sub_model=None,
        max_llm_calls: int = 50,
        max_steps: int = 20,
        tools: list = None,
        verbosity_level: int = 1,
        track_calls: bool = True,
        batched_thread_safe: bool = False,
        **kwargs
    ):
        self.sub_model = sub_model or model
        self.max_llm_calls = max_llm_calls
        # Shared counter
        self._counter = LLMCallCounter(max_calls=max_llm_calls)
        # Create tools with shared counter
        self._llm_query = LLMQueryTool(model=self.sub_model, counter=self._counter)
        self._llm_query_batched = LLMQueryBatchedTool(
            model=self.sub_model,
            counter=self._counter,
            thread_safe=batched_thread_safe,
        )
        all_tools = [self._llm_query, self._llm_query_batched]
        if tools:
            all_tools.extend(tools)
        # RLM instructions
        base_instructions = kwargs.pop('instructions', '') or ''
        rlm_instructions = base_instructions + RLM_INSTRUCTIONS.format(max_llm_calls=max_llm_calls)
        # Step callbacks
        user_callbacks = kwargs.pop('step_callbacks', {})
        if track_calls:
            rlm_callback = make_rlm_step_callback(self._counter, verbose=(verbosity_level >= 2))
            if isinstance(user_callbacks, dict):
                existing = user_callbacks.get(ActionStep, [])
                if not isinstance(existing, list):
                    existing = [existing]
                user_callbacks[ActionStep] = existing + [rlm_callback]
            elif isinstance(user_callbacks, list):
                user_callbacks = user_callbacks + [rlm_callback]
        super().__init__(
            model=model,
            tools=all_tools,
            max_steps=max_steps,
            verbosity_level=verbosity_level,
            instructions=rlm_instructions,
            step_callbacks=user_callbacks,
            **kwargs
        )

    def run(
        self,
        task: str,
        context: str = None,
        reset: bool = True,
        show_metadata: bool = True,
        **kwargs
    ):
        """
        Run agent on task with optional large context.

        Args:
            task: The task/question
            context: Large text (passed as variable, not stuffed in prompt)
            reset: Reset call counter (default True)
            show_metadata: Print variable metadata (default True)
            **kwargs: Additional variables via additional_args
        """
        if reset:
            self._counter.reset()
        additional_args = kwargs.pop("additional_args", {}) or {}
        variables_info: list[str] = []

        def _record_var(name: str, value: Any, preview_chars: int = 400) -> None:
            info = make_variable_info(name, value, preview_chars=preview_chars)
            variables_info.append(info)
            if show_metadata:
                print(f"\n{'='*60}\nVariable Metadata:\n{'='*60}\n{info}\n")

        # Keep large data out of the prompt. Store it in state, show metadata instead.
        if context is not None:
            self.state["context"] = context
            _record_var("context", context)
        # Route other large args into state + metadata, keep small args in prompt.
        for k, v in kwargs.items():
            if isinstance(v, str) and len(v) > 1000:
                self.state[k] = v
                _record_var(k, v, preview_chars=200)
            else:
                additional_args[k] = v
        if variables_info:
            additional_args["variables_info"] = "\n\n".join(variables_info)
        return super().run(task=task, additional_args=additional_args, reset=reset)

    @property
    def llm_calls_used(self) -> int:
        return self._counter.count

    @property
    def llm_calls_remaining(self) -> int:
        return self._counter.remaining
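

# --- Illustrative usage sketch ------------------------------------------------
# Shows the intended call pattern for RLMAgent.run: the large text goes in via
# `context` (and any string kwarg longer than 1,000 chars) and lands in
# agent.state with only its metadata in the prompt, while small kwargs travel
# through additional_args. `model` is any smolagents model object supplied by
# the caller; the log text, question, and function name are made up for this
# sketch.
def _example_run_with_large_context(model) -> None:
    big_log = "2024-01-01 ERROR billing: upstream timeout\n" * 2000  # well over 1,000 chars
    agent = RLMAgent(model=model, max_llm_calls=30)
    answer = agent.run(
        task="Which service logged the most errors?",
        context=big_log,            # stored as agent.state["context"]
        service_name="billing",     # small value, passed via additional_args
    )
    print(answer)
    print(agent.llm_calls_used, agent.llm_calls_remaining)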
# =============================================================================
# Convenience Function
# =============================================================================
def create_rlm_agent(
    model_id: str = "gpt-4.1-nano",
    sub_model_id: str = None,
    max_llm_calls: int = 50,
    max_steps: int = 20,
    **kwargs
) -> RLMAgent:
    """
    Convenience function to create an RLMAgent with LiteLLM.

    Args:
        model_id: Main model ID (default: gpt-4.1-nano)
        sub_model_id: Sub-model for llm_query (default: same as model_id)
        max_llm_calls: Max sub-LLM calls (default: 50)
        max_steps: Max iterations (default: 20)

    Returns:
        Configured RLMAgent

    Example:
        agent = create_rlm_agent("gpt-4.1", sub_model_id="gpt-4.1-nano")
        result = agent.run("Summarize this", context=text)
    """
    from smolagents import LiteLLMModel
    model = LiteLLMModel(model_id=model_id)
    sub_model = LiteLLMModel(model_id=sub_model_id) if sub_model_id else None
    return RLMAgent(
        model=model,
        sub_model=sub_model,
        max_llm_calls=max_llm_calls,
        max_steps=max_steps,
        **kwargs
    )
# =============================================================================
# Demo
# =============================================================================
if __name__ == "__main__":
print("RLM v2 Module for Smolagents")
print("=" * 60)
print("""
This module provides RLM (Recursive Language Model) capabilities:
1. llm_query(prompt) - Single sub-LLM call for semantic analysis
2. llm_query_batched(prompts) - Parallel sub-LLM calls (8x faster)
3. RLM strategies guidance in system prompt
4. Variable metadata for large contexts
5. Shared call counter with step tracking
Usage:
from rlm_v2 import RLMAgent, create_rlm_agent
# Option 1: Create with model object
from smolagents import LiteLLMModel
model = LiteLLMModel(model_id="gpt-4.1-nano")
agent = RLMAgent(model=model)
# Option 2: Convenience function
agent = create_rlm_agent("gpt-4.1-nano")
# Option 3: With cheaper sub-model for llm_query
agent = create_rlm_agent("gpt-4.1", sub_model_id="gpt-4.1-nano")
# Run on large context
result = agent.run(
task="How many entries are about 'Python'?",
context=large_text
)
print(f"Calls used: {agent.llm_calls_used}")
print(f"Calls remaining: {agent.llm_calls_remaining}")
Strategies (automatically included in prompt):
- PEEK: Always explore data structure first
- GREP: Use string/regex matching before LLM calls
- PARTITION + MAP: Chunk and batch for semantic tasks
- VERIFY: Check results before final_answer
""")
# Quick test of shared counter
print("Testing shared counter...")
counter = LLMCallCounter(max_calls=10)
print(f" Initial: remaining={counter.remaining}, count={counter.count}")
counter.check_and_increment(3)
print(f" After +3: remaining={counter.remaining}, count={counter.count}")
counter.reset()
print(f" After reset: remaining={counter.remaining}, count={counter.count}")
print("\nModule ready!")