RLM - Recursive Language Models - Smolagents Implementation

#!/usr/bin/env python3
"""
RLM v2 (Recursive Language Model) Module for Smolagents

Implements DSPy RLM-style strategies for handling large contexts:
- Peeking: Look at data structure before processing
- Grepping: Use string/regex matching before LLM calls
- Partition + Map: Chunk data and process with batched sub-LLM calls
- Summarization: Hierarchical compression for understanding the whole
- Hybrid: Combine strategies as needed

Features:
- Shared call counter between llm_query and llm_query_batched
- Step callbacks for tracking/debugging
- Variable metadata display
- RLM strategy guidance in prompts

Usage:
    from rlm_v2 import RLMAgent, create_rlm_agent

    agent = create_rlm_agent("gpt-4.1-nano")
    result = agent.run(
        task="Classify each review as positive/negative",
        context=large_text
    )
"""
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable

from smolagents import CodeAgent, Tool
from smolagents.agents import ActionStep


# =============================================================================
# Shared Call Counter
# =============================================================================

class LLMCallCounter:
    """Thread-safe counter shared between llm_query tools."""

    def __init__(self, max_calls: int = 50):
        self.max_calls = max_calls
        self._count = 0
        self._lock = threading.Lock()

    def check_and_increment(self, n: int = 1) -> int:
        """Check if we can make n calls, increment counter, return new count."""
        with self._lock:
            if self._count + n > self.max_calls:
                raise RuntimeError(
                    f"LLM call limit exceeded: {self._count} + {n} > {self.max_calls}. "
                    "Use Python for aggregation instead of more LLM calls."
                )
            self._count += n
            return self._count

    def reset(self):
        with self._lock:
            self._count = 0

    @property
    def remaining(self) -> int:
        with self._lock:
            return self.max_calls - self._count

    @property
    def count(self) -> int:
        with self._lock:
            return self._count
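
# Example (illustrative): the two tools below are constructed with the same
# LLMCallCounter instance, so one llm_query call plus a batch of 9 prompts to
# llm_query_batched consumes 10 calls from the shared budget, and
# check_and_increment() raises RuntimeError once max_calls would be exceeded.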


# =============================================================================
# RLM Tools with Shared Counter
# =============================================================================

class LLMQueryTool(Tool):
    """Query a sub-LLM for semantic analysis."""

    name = "llm_query"
    description = (
        "Query a language model for semantic analysis (classification, summarization, understanding). "
        "Use Python string matching for pattern/location tasks - it's free and faster."
    )
    inputs = {
        "prompt": {"type": "string", "description": "The prompt to send to the LLM"}
    }
    output_type = "string"

    def __init__(self, model=None, counter: LLMCallCounter = None, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.counter = counter or LLMCallCounter()

    def forward(self, prompt: str) -> str:
        self.counter.check_and_increment(1)
        if not self.model:
            return "[ERROR] No model configured"
        try:
            from smolagents.models import ChatMessage, MessageRole

            messages = [ChatMessage(role=MessageRole.USER, content=prompt)]
            response = self.model.generate(messages)
            return response.content
        except Exception as e:
            return f"[ERROR] {e}"


class LLMQueryBatchedTool(Tool):
    """Query sub-LLM with multiple prompts in parallel."""

    name = "llm_query_batched"
    description = (
        "Query LLM with multiple prompts in PARALLEL (8x faster than sequential). "
        "Use for batch classification/analysis. Returns list of responses."
    )
    inputs = {
        "prompts": {"type": "array", "description": "List of prompts"}
    }
    output_type = "array"

    def __init__(
        self,
        model=None,
        counter: LLMCallCounter = None,
        max_workers: int = 8,
        thread_safe: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model = model
        self.counter = counter or LLMCallCounter()
        # If model is not thread-safe, force serial execution.
        self.max_workers = max_workers if thread_safe else 1

    def forward(self, prompts: list) -> list:
        if not prompts:
            return []
        n = len(prompts)
        self.counter.check_and_increment(n)
        if not self.model:
            return ["[ERROR] No model configured"] * n

        from smolagents.models import ChatMessage, MessageRole

        def query_one(prompt: str) -> str:
            try:
                messages = [ChatMessage(role=MessageRole.USER, content=prompt)]
                return self.model.generate(messages).content
            except Exception as e:
                return f"[ERROR] {e}"

        results = {}
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {executor.submit(query_one, p): i for i, p in enumerate(prompts)}
            for future in as_completed(futures):
                results[futures[future]] = future.result()
        return [results[i] for i in range(n)]
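
# Note on the batched tool above: results are reassembled by index, so the
# returned list always matches the order of `prompts` even though futures
# complete out of order. Pass thread_safe=True only if the underlying model
# client can handle concurrent generate() calls; otherwise the tool falls back
# to max_workers=1 and the batch runs sequentially.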


# =============================================================================
# Variable Metadata
# =============================================================================

def make_variable_info(name: str, value: Any, preview_chars: int = 1000) -> str:
    """
    Create RLM-style variable metadata showing type, size, and preview.

    This helps the LLM understand the data without seeing all of it,
    encouraging strategic exploration (peeking) before processing.
    """
    if isinstance(value, str):
        total_len = len(value)
        lines = value.split('\n')
        preview = value[:preview_chars]
        truncated = len(value) > preview_chars
        return f"""Variable: `{name}` (access it in your code)
Type: str
Total length: {total_len:,} characters
Total lines: {len(lines):,}
Preview (first {preview_chars} chars):
```
{preview}{'...' if truncated else ''}
```"""
    elif isinstance(value, (list, tuple)):
        type_name = type(value).__name__
        total = len(value)
        preview = str(value[:10]) if total > 10 else str(value)
        return f"""Variable: `{name}` (access it in your code)
Type: {type_name}
Total items: {total:,}
Preview (first 10): {preview}"""
    elif isinstance(value, dict):
        keys = list(value.keys())
        preview_keys = keys[:10]
        return f"""Variable: `{name}` (access it in your code)
Type: dict
Total keys: {len(keys):,}
Keys preview: {preview_keys}"""
    else:
        str_val = str(value)[:preview_chars]
        return f"""Variable: `{name}` (access it in your code)
Type: {type(value).__name__}
Value: {str_val}"""


# Alias for compatibility
make_var_info = make_variable_info
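
# Illustrative output for a long string variable (numbers are made up):
#   Variable: `context` (access it in your code)
#   Type: str
#   Total length: 1,234,567 characters
#   Total lines: 45,678
#   Preview (first 1000 chars):
#   ...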


# =============================================================================
# OOLONG Prompt Parsing
# =============================================================================

def _strip_chat_markers(text: str) -> str:
    """Remove common chat markers used by some datasets."""
    if text.startswith("<|im_start|>user\n"):
        text = text[len("<|im_start|>user\n") :]
    for suffix in ("<|im_start|>assistant\n", "<|im_end|>\n", "<|im_end|>"):
        if text.endswith(suffix):
            text = text[: -len(suffix)]
    return text


def parse_oolong_prompt(prompt: str) -> tuple[str, str, str]:
    """
    Parse an OOLONG prompt into (header, context, question).

    The prompt format is typically:
    - An instruction header
    - Many lines of data starting with 'Date:'
    - A trailing question that starts with 'In the above data'
    """
    text = _strip_chat_markers(prompt)
    marker = "\n\nIn the above data"
    idx = text.rfind(marker)
    if idx == -1:
        raise ValueError("Could not find question marker in prompt.")
    body = text[:idx].rstrip()
    question = text[idx + 2 :].strip()
    lines = body.splitlines()
    data_start = None
    for i, line in enumerate(lines):
        if line.startswith("Date:"):
            data_start = i
            break
    if data_start is None:
        header = ""
        context = body.strip()
    else:
        header = "\n".join(lines[:data_start]).strip()
        context = "\n".join(lines[data_start:]).strip()
    return header, context, question
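
# Illustrative use on a hypothetical OOLONG-style prompt (not real dataset content):
#   prompt = (
#       "Answer using the data below.\n"
#       "Date: 2021-03-01 | ...\n"
#       "Date: 2021-03-02 | ...\n"
#       "\n"
#       "In the above data, how many entries were posted in March?"
#   )
#   header, context, question = parse_oolong_prompt(prompt)
#   # header   -> "Answer using the data below."
#   # context  -> the two "Date:" lines
#   # question -> "In the above data, how many entries were posted in March?"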


# =============================================================================
# Step Callback for Debugging
# =============================================================================

def make_rlm_step_callback(counter: LLMCallCounter, verbose: bool = True) -> Callable:
    """Create a step callback that logs RLM metrics."""

    def callback(memory_step: ActionStep, agent: "RLMAgent") -> None:
        if not verbose:
            return
        step_num = memory_step.step_number
        calls_used = counter.count
        calls_remaining = counter.remaining
        # Log to observations
        rlm_info = f"\n[RLM] Step {step_num} | LLM calls: {calls_used}/{counter.max_calls} used, {calls_remaining} remaining"
        if memory_step.observations:
            memory_step.observations += rlm_info
        else:
            memory_step.observations = rlm_info

    return callback


# =============================================================================
# RLM Prompt Addition
# =============================================================================

RLM_INSTRUCTIONS = """
## RLM Strategies for Large Context

You have access to `llm_query(prompt)` and `llm_query_batched(prompts)` for semantic analysis.
You have {max_llm_calls} sub-LLM calls available. Use them wisely!

### Strategy 0: READ THE METADATA FIRST
You will receive `variables_info` in the task. Read it before doing anything else.
It tells you type/size/preview so you can decide how to explore the data.

### Strategy 1: PEEK FIRST (Always do this!)
```python
print(f"Total length: {{len(context):,}} chars")
print(f"Lines: {{len(context.split(chr(10))):,}}")
print(context[:2000])  # See the format before processing
```

### Strategy 2: GREP WHEN POSSIBLE
String/regex matching is FREE and FAST. Use it before LLM calls:
```python
import re
matches = [l for l in context.split('\\n') if 'keyword' in l]
# Or with regex: matches = re.findall(r'pattern', context)
```

### Strategy 3: PARTITION + MAP (for semantic tasks)
When you NEED semantic understanding, chunk and batch:
```python
lines = context.split('\\n')
chunks = [lines[i:i+50] for i in range(0, len(lines), 50)]
prompts = [f"Classify these:\\n{{chr(10).join(c)}}" for c in chunks]
results = llm_query_batched(prompts)  # Parallel! 8x faster than a loop
```

### Strategy 4: VERIFY BEFORE final_answer
```python
print(f"Found {{count}} matches")
print(matches[:5])  # Sanity check samples
# Then: final_answer(count)
```

### Strategy 5: SUMMARIZE HIERARCHICALLY
When the data is too big, summarize chunks, then summarize the summaries.
```python
chunks = [lines[i:i+200] for i in range(0, len(lines), 200)]
prompts = [f"Summarize:\\n{{chr(10).join(c)}}" for c in chunks]
summaries = llm_query_batched(prompts)
final = llm_query("Combine these summaries into 5 bullet points:\\n" + "\\n".join(summaries))
```

### Strategy 6: HYBRID
Mix grep + summarize: use regex to reduce the data, then summarize the reduced set.

### DON'T WASTE LLM CALLS ON:
- Counting (use Python: `len()`, `sum()`)
- Pattern matching (use Python: `in`, `re.search()`)
- Filtering (use Python: list comprehensions)
"""


# =============================================================================
# RLMAgent Class
# =============================================================================

class RLMAgent(CodeAgent):
    """
    Recursive Language Model Agent.

    Extends CodeAgent with:
    - llm_query() / llm_query_batched() for sub-LLM calls
    - Shared call counter
    - RLM strategy guidance
    - Step callbacks for debugging
    """

    def __init__(
        self,
        model,
        sub_model=None,
        max_llm_calls: int = 50,
        max_steps: int = 20,
        tools: list = None,
        verbosity_level: int = 1,
        track_calls: bool = True,
        batched_thread_safe: bool = False,
        **kwargs,
    ):
        self.sub_model = sub_model or model
        self.max_llm_calls = max_llm_calls

        # Shared counter
        self._counter = LLMCallCounter(max_calls=max_llm_calls)

        # Create tools with shared counter
        self._llm_query = LLMQueryTool(model=self.sub_model, counter=self._counter)
        self._llm_query_batched = LLMQueryBatchedTool(
            model=self.sub_model,
            counter=self._counter,
            thread_safe=batched_thread_safe,
        )
        all_tools = [self._llm_query, self._llm_query_batched]
        if tools:
            all_tools.extend(tools)

        # RLM instructions
        base_instructions = kwargs.pop('instructions', '') or ''
        rlm_instructions = base_instructions + RLM_INSTRUCTIONS.format(max_llm_calls=max_llm_calls)

        # Step callbacks
        user_callbacks = kwargs.pop('step_callbacks', {})
        if track_calls:
            rlm_callback = make_rlm_step_callback(self._counter, verbose=(verbosity_level >= 2))
            if isinstance(user_callbacks, dict):
                existing = user_callbacks.get(ActionStep, [])
                if not isinstance(existing, list):
                    existing = [existing]
                user_callbacks[ActionStep] = existing + [rlm_callback]
            elif isinstance(user_callbacks, list):
                user_callbacks = user_callbacks + [rlm_callback]

        super().__init__(
            model=model,
            tools=all_tools,
            max_steps=max_steps,
            verbosity_level=verbosity_level,
            instructions=rlm_instructions,
            step_callbacks=user_callbacks,
            **kwargs,
        )

    def run(
        self,
        task: str,
        context: str = None,
        reset: bool = True,
        show_metadata: bool = True,
        **kwargs,
    ):
        """
        Run the agent on a task with an optional large context.

        Args:
            task: The task/question.
            context: Large text, stored in agent state as `context` rather than stuffed into the prompt.
            reset: Reset the call counter (default True).
            show_metadata: Print variable metadata (default True).
            **kwargs: Additional variables; strings longer than 1000 chars go to state
                with metadata, everything else is passed via additional_args.
        """
        if reset:
            self._counter.reset()

        additional_args = kwargs.pop("additional_args", {}) or {}
        variables_info: list[str] = []

        def _record_var(name: str, value: Any, preview_chars: int = 400) -> None:
            info = make_variable_info(name, value, preview_chars=preview_chars)
            variables_info.append(info)
            if show_metadata:
                print(f"\n{'='*60}\nVariable Metadata:\n{'='*60}\n{info}\n")

        # Keep large data out of the prompt. Store it in state, show metadata instead.
        if context is not None:
            self.state["context"] = context
            _record_var("context", context)

        # Route other large args into state + metadata, keep small args in the prompt.
        for k, v in kwargs.items():
            if isinstance(v, str) and len(v) > 1000:
                self.state[k] = v
                _record_var(k, v, preview_chars=200)
            else:
                additional_args[k] = v

        if variables_info:
            additional_args["variables_info"] = "\n\n".join(variables_info)

        return super().run(task=task, additional_args=additional_args, reset=reset)

    @property
    def llm_calls_used(self) -> int:
        return self._counter.count

    @property
    def llm_calls_remaining(self) -> int:
        return self._counter.remaining


# =============================================================================
# Convenience Function
# =============================================================================

def create_rlm_agent(
    model_id: str = "gpt-4.1-nano",
    sub_model_id: str = None,
    max_llm_calls: int = 50,
    max_steps: int = 20,
    **kwargs,
) -> RLMAgent:
    """
    Convenience function to create an RLMAgent with LiteLLM.

    Args:
        model_id: Main model ID (default: gpt-4.1-nano)
        sub_model_id: Sub-model for llm_query (default: same as model_id)
        max_llm_calls: Max sub-LLM calls (default: 50)
        max_steps: Max iterations (default: 20)

    Returns:
        Configured RLMAgent

    Example:
        agent = create_rlm_agent("gpt-4.1", sub_model_id="gpt-4.1-nano")
        result = agent.run("Summarize this", context=text)
    """
    from smolagents import LiteLLMModel

    model = LiteLLMModel(model_id=model_id)
    sub_model = LiteLLMModel(model_id=sub_model_id) if sub_model_id else None
    return RLMAgent(
        model=model,
        sub_model=sub_model,
        max_llm_calls=max_llm_calls,
        max_steps=max_steps,
        **kwargs,
    )


# =============================================================================
# Demo
# =============================================================================

if __name__ == "__main__":
    print("RLM v2 Module for Smolagents")
    print("=" * 60)
    print("""
This module provides RLM (Recursive Language Model) capabilities:

1. llm_query(prompt) - Single sub-LLM call for semantic analysis
2. llm_query_batched(prompts) - Parallel sub-LLM calls (8x faster)
3. RLM strategies guidance in the system prompt
4. Variable metadata for large contexts
5. Shared call counter with step tracking

Usage:
    from rlm_v2 import RLMAgent, create_rlm_agent

    # Option 1: Create with a model object
    from smolagents import LiteLLMModel
    model = LiteLLMModel(model_id="gpt-4.1-nano")
    agent = RLMAgent(model=model)

    # Option 2: Convenience function
    agent = create_rlm_agent("gpt-4.1-nano")

    # Option 3: With a cheaper sub-model for llm_query
    agent = create_rlm_agent("gpt-4.1", sub_model_id="gpt-4.1-nano")

    # Run on a large context
    result = agent.run(
        task="How many entries are about 'Python'?",
        context=large_text
    )
    print(f"Calls used: {agent.llm_calls_used}")
    print(f"Calls remaining: {agent.llm_calls_remaining}")

Strategies (automatically included in the prompt):
- PEEK: Always explore the data structure first
- GREP: Use string/regex matching before LLM calls
- PARTITION + MAP: Chunk and batch for semantic tasks
- VERIFY: Check results before final_answer
""")

    # Quick test of the shared counter
    print("Testing shared counter...")
    counter = LLMCallCounter(max_calls=10)
    print(f" Initial: remaining={counter.remaining}, count={counter.count}")
    counter.check_and_increment(3)
    print(f" After +3: remaining={counter.remaining}, count={counter.count}")
    counter.reset()
    print(f" After reset: remaining={counter.remaining}, count={counter.count}")

    print("\nModule ready!")