Created
December 16, 2024 00:37
-
-
Save cnndabbler/6dcb488f4c21d0d73730270c62e11a47 to your computer and use it in GitHub Desktop.
Code implementing an agentic collaboration between a student and two AI agents, the 'Math Solver Agent' and the 'Math Validator Agent', which work together to solve math problems.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from openai import OpenAI | |
from rich import print | |
from rich.console import Console | |
from rich.panel import Panel | |
from rich.prompt import Prompt | |
from rich.markdown import Markdown | |
from rich.text import Text | |
from rich.style import Style | |
from swarm import Swarm, Agent | |
from swarm.types import Response | |
import sys | |
from datetime import datetime | |
import os | |
from dataclasses import dataclass | |
from typing import Optional, Dict, Any | |
# Model-server configuration. The defaults target a local Ollama server via the
# OpenAI-compatible "/v1" endpoint; set OLLAMA_BASE_URL, OLLAMA_API_KEY or
# MATH_AGENT_MODEL to use another server or model without editing this file.
client = OpenAI(
    base_url=os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434/v1"),
    api_key=os.environ.get("OLLAMA_API_KEY", "ollama"),
)
swarm_client = Swarm(client=client)

# Shared Rich console used for all terminal output.
console = Console()

# Model used by both agents. It must support function/tool calling, because both
# agents expose handoff functions ("vanilj/Phi-4", for example, does not).
# Other models tried: "qwq", "qwen2.5-coder:32b".
model = os.environ.get("MATH_AGENT_MODEL", "qwen2.5:32b")
def print_context(prefix: str, context_variables: Dict[str, Any]) -> None:
    """Print the current context variables inside a titled Rich panel.

    Args:
        prefix: Label shown in the heading, e.g. "Before Validation".
        context_variables: Mapping whose key/value pairs are listed one per
            line as ``key: value``.

    Keys and values are rendered as plain ``Text`` rather than Rich markup.
    They often contain model output with square brackets (``[x]``, ``[✓]``,
    ``[/...]``). Rich would otherwise treat such text as style tags, hiding it
    or raising a MarkupError.
    """
    console.print(f"\n[bold cyan]Context Variables ({prefix}):[/bold cyan]")
    lines = [
        Text.assemble((str(key), "yellow"), f": {value}")
        for key, value in context_variables.items()
    ]
    console.print(Panel(
        Text("\n").join(lines),
        title="Current State",
        border_style="cyan",
    ))
def request_validator_handoff(context_variables: Dict[str, Any]) -> Dict[str, Any]:
    """Ask for the current solution step to be checked by the validator agent.

    Registered as a tool on the solver agent. Returns a plain dict naming the
    target ("validator") and a fixed reason, plus the step number and step
    text read from ``context_variables`` (0 and "" when those keys are absent).
    """
    step_number = context_variables.get("current_step", 0)
    step_text = context_variables.get("current_content", "")
    return dict(
        handoff="validator",
        reason="Step validation needed",
        current_step=step_number,
        current_content=step_text,
    )
def request_solver_handoff(context_variables: Dict[str, Any]) -> Dict[str, Any]:
    """Hand control back to the solver agent.

    Registered as a tool on the validator agent. The reason comes from
    ``context_variables["handoff_reason"]`` (default "Major revision needed")
    and the step number from ``current_step`` (default 0).
    """
    why = context_variables.get("handoff_reason", "Major revision needed")
    step_number = context_variables.get("current_step", 0)
    payload = {"handoff": "solver"}
    payload["reason"] = why
    payload["current_step"] = step_number
    return payload
def create_math_solver_agent() -> Agent:
    """Build the tutoring agent that walks the user through a problem step by step.

    The agent runs on the module-level ``model`` and is given
    ``request_validator_handoff`` as its only tool. Its instructions ask it to
    end a finished solution with the literal marker "[SOLUTION COMPLETE]".
    """
    solver_instructions = """You are a step-by-step math problem solver. Your role is to guide users through solving math problems one small step at a time.
Key Guidelines:
1. Keep each step SMALL and FOCUSED - only explain ONE concept at a time
2. WAIT for user confirmation after EACH step
3. Never proceed to the next step until the user confirms understanding
4. Keep explanations brief and clear
5. If user asks for a hint, provide a small hint focused on the current step only
6. After completing each step, request validation
IMPORTANT: Only mark a solution as complete when ALL of these are true:
1. All steps of the problem have been solved
2. The final answer has been found and clearly stated
3. All work has been shown and validated
When the solution is complete:
1. Clearly state the final answer
2. Summarize all steps taken
3. End your response with "[SOLUTION COMPLETE]"
Format your responses as:
Step X: [One sentence description]
[Brief explanation - max 2-3 sentences]
[Single focused question about current step]"""
    solver_tools = [request_validator_handoff]
    return Agent(
        functions=solver_tools,
        instructions=solver_instructions,
        model=model,
        name="MathSolverAgent",
    )
def create_validator_agent() -> Agent:
    """Build the agent that reviews each solver step.

    The agent runs on the module-level ``model`` and is given
    ``request_solver_handoff`` as its only tool. Its instructions constrain each
    reply to start with one of "[✓] Correct:", "[!] Note:" or "[×] Error:". They
    also reserve the "[SOLUTION COMPLETE]" marker for a fully verified answer.
    """
    validator_instructions = """You are a mathematical validation expert who reviews solution steps and provides valuable insights.
Your role is to:
1. Verify the mathematical correctness of each step
2. Point out any potential alternative approaches
3. Add helpful insights or connections to other mathematical concepts
4. If you spot an error, clearly explain what's wrong
5. Keep your responses brief and focused
Format your responses as EXACTLY ONE of these formats:
[✓] Correct: [Brief confirmation of the step]
or
[!] Note: [Brief insight or alternative approach]
or
[×] Error: [Brief explanation of the error]
IMPORTANT: A solution is NEVER complete until ALL of these conditions are met:
1. The problem has been fully solved with ALL necessary steps shown
2. A clear, unambiguous final answer has been stated and verified
3. ALL mathematical work has been validated step by step
4. NO steps are missing or unclear
5. The solution has been tested with numerical verification
6. All algebraic manipulations are correct and complete
7. The final answer satisfies the original equation when substituted back
COMPLETION CHECKLIST:
Before marking a solution complete, verify:
- Have we solved for the unknown variable completely?
- Is the final answer explicitly stated?
- Have we verified the answer by substituting back?
- Are ALL steps shown and explained?
- Have we performed numerical verification?
Only when ALL of these conditions are met, append "[SOLUTION COMPLETE]" to your response.
NEVER mark intermediate steps as complete, even if they are correct.
If ANY step is missing or unclear, provide feedback to the solver."""
    validator_tools = [request_solver_handoff]
    return Agent(
        functions=validator_tools,
        instructions=validator_instructions,
        model=model,
        name="ValidatorAgent",
    )
def handle_agent_handoff(
    source_agent: str,
    target_agent: str,
    context_variables: Dict[str, Any],
    loop_threshold: int = 3,
) -> Dict[str, Any]:
    """Record a handoff between two agents in the shared context.

    ``context_variables`` is updated in place and the same dict is returned:

    * ``handoff_count[str(step)][f"{source}_to_{target}"]`` is incremented. A
      warning is printed once that count exceeds ``loop_threshold``, which
      suggests the agents may be bouncing the same step back and forth.
    * ``previous_agent``, ``current_agent``, ``handoff_count`` and
      ``last_handoff_time`` are set.
    * An entry is appended to ``handoff_history``.

    The context is printed before and after the update.

    Args:
        source_agent: Name of the agent handing off (e.g. "solver").
        target_agent: Name of the agent receiving control (e.g. "validator").
        context_variables: Shared state; ``current_step`` is treated as 1 if unset.
        loop_threshold: Per-step count of one handoff direction above which a
            potential-loop warning is printed.

    Returns:
        The same ``context_variables`` object, after mutation.
    """
    current_step = context_variables.get("current_step", 1)
    step_key = str(current_step)
    handoff_count = context_variables.get("handoff_count", {})

    print_context(f"Before Handoff ({source_agent} -> {target_agent})", context_variables)

    # Counters are kept per step and per direction, e.g.
    # {"2": {"solver_to_validator": 1, "validator_to_solver": 1}}.
    handoff_key = f"{source_agent}_to_{target_agent}"
    step_counts = handoff_count.setdefault(step_key, {})
    step_counts[handoff_key] = step_counts.get(handoff_key, 0) + 1
    current_count = step_counts[handoff_key]

    if current_count > loop_threshold:
        console.print(f"\n[bold red]Warning: Detected potential loop - {handoff_key} has occurred {current_count} times in step {current_step}[/bold red]")
    console.print(f"\n[dim]Debug: Handoff from {source_agent} to {target_agent} (Step {current_step}, Count: {current_count})[/dim]")

    # Take one timestamp so last_handoff_time and the history entry agree.
    timestamp = datetime.now().isoformat()
    context_variables.update({
        "previous_agent": source_agent,
        "current_agent": target_agent,
        "handoff_count": handoff_count,
        "last_handoff_time": timestamp,
    })

    handoff_history = context_variables.get("handoff_history", [])
    handoff_history.append({
        "step": current_step,
        "from": source_agent,
        "to": target_agent,
        "count": current_count,
        "timestamp": timestamp,
    })
    context_variables["handoff_history"] = handoff_history

    print_context(f"After Handoff ({source_agent} -> {target_agent})", context_variables)
    return context_variables
def process_and_print_streaming_response(response):
    """Print a Swarm streaming response live and collect it into a ``Response``.

    Iterates over the chunks produced by ``swarm_client.run(..., stream=True)``:

    * string chunks are ignored;
    * when a chunk's ``sender`` differs from the previous one, the sender's
      name is printed as a header;
    * non-None ``content`` deltas are printed as they arrive and accumulated;
    * a ``{"delim": "end"}`` chunk closes the current output line;
    * a ``response`` chunk is printed and appended unless its string form
      starts with "messages=". That prefix is presumably the repr of Swarm's
      final Response object; TODO confirm against the installed swarm version;
    * dict entries under a ``context_variables`` key are merged.

    Model text is printed with ``markup=False`` so bracketed output such as
    ``[✓]``, ``[x]`` or ``[/...]`` appears verbatim. Otherwise Rich would treat
    it as style tags, hiding it or raising a MarkupError.

    Returns:
        A ``Response`` with one assistant message holding all accumulated text,
        plus the merged context variables; ``None`` if building it fails.
    """
    full_response = ""
    content = ""  # text printed since the last line break we emitted
    last_sender = ""
    context_variables = {}

    for chunk in response:
        if isinstance(chunk, str):
            continue

        sender = chunk.get("sender")
        if sender and sender != last_sender:
            if content:
                console.print()
            console.print(f"\n[blue]{sender}:[/blue] ", end="")
            last_sender = sender

        delta = chunk.get("content")
        if delta is not None:
            console.print(delta, end="", markup=False)
            sys.stdout.flush()
            content += delta
            full_response += delta

        if chunk.get("delim") == "end":
            if content:
                console.print()
            content = ""

        if "response" in chunk:
            final_text = str(chunk["response"])
            if final_text.startswith("messages="):
                continue
            full_response += final_text
            console.print(final_text, style="blue", end="", markup=False)
            sys.stdout.flush()

        if "context_variables" in chunk:
            context_variables.update(chunk["context_variables"])

    if content:
        console.print()

    try:
        return Response(
            messages=[{"role": "assistant", "content": full_response}],
            context_variables=context_variables,
        )
    except Exception as e:
        console.print(f"[red]Error processing response: {e}[/red]")
        return None
def process_validation(validator: Agent, step_content: str, problem: str, context_variables: Dict[str, Any]) -> tuple:
    """Ask the validator agent to check one solver step.

    Records a solver -> validator handoff and streams the validator's verdict
    to the console. If a verdict was received, it then records the
    validator -> solver handoff.

    The step counts as wrong when the verdict starts with "[×]". In that case
    the verdict is stored as ``handoff_reason`` so ``request_solver_handoff``
    can report it. On a passing verdict, any ``handoff_reason`` left over from
    an earlier failed attempt is removed. The solution counts as finished when
    the verdict contains "[SOLUTION COMPLETE]".

    Args:
        validator: The validator agent.
        step_content: Text of the solver step to check.
        problem: Original problem statement, quoted in the validation request.
        context_variables: Shared state, updated in place.

    Returns:
        ``(response_obj, has_error, error_message, solution_complete)``.
        ``response_obj`` is the collected validator Response (or None).
        ``error_message`` is the verdict text when ``has_error`` is True,
        otherwise "".
    """
    print_context("Before Validation", context_variables)
    validation_context = handle_agent_handoff("solver", "validator", context_variables)

    validation_msg = {
        "role": "user",
        "content": f"Validate this solution step for the problem '{problem}':\n\n{step_content}",
    }
    console.print("\n[dim]Debug: Starting validation process...[/dim]")
    response = swarm_client.run(
        agent=validator,
        model_override=model,
        messages=[validation_msg],
        context_variables=validation_context,
        stream=True,
    )
    response_obj = process_and_print_streaming_response(response)

    has_error = False
    error_message = ""
    solution_complete = False

    if response_obj and response_obj.messages:
        validation_content = response_obj.messages[0]["content"].replace("ValidatorAgent:", "").strip()
        has_error = validation_content.startswith("[×]")
        solution_complete = "[SOLUTION COMPLETE]" in validation_content
        console.print(f"\n[dim]Debug: Validation result - Error: {has_error}, Solution Complete: {solution_complete}[/dim]")

        if has_error:
            error_message = validation_content
            validation_context["handoff_reason"] = error_message
            debug_note = "Validation failed, returning to solver"
        else:
            # A passing verdict must not inherit the reason from an earlier failure.
            validation_context.pop("handoff_reason", None)
            if solution_complete:
                debug_note = "Solution complete, proceeding to final verification"
            else:
                debug_note = "Validation passed, returning to solver for next step"
        console.print(f"\n[dim]Debug: {debug_note}[/dim]")
        # Control always returns to the solver, whatever the verdict.
        context_variables.update(handle_agent_handoff("validator", "solver", validation_context))

    print_context("After Validation", context_variables)

    # Merge any context variables collected from the validator's response.
    if response_obj and response_obj.context_variables:
        context_variables.update(response_obj.context_variables)

    return response_obj, has_error, error_message, solution_complete
def get_solver_step(solver: Agent, messages: list, context_variables: Dict[str, Any], error_feedback: str = None) -> tuple:
    """Run the solver agent once and return its streamed reply.

    When ``error_feedback`` is given, a system message asking for a revision
    is appended to a copy of ``messages``; the caller's list is left untouched.
    Context variables carried by the reply are merged into
    ``context_variables``.

    Returns:
        The ``Response`` from ``process_and_print_streaming_response`` (or
        None), despite the ``tuple`` annotation.
    """
    print_context("Before Getting Solver Step", context_variables)

    outgoing = list(messages)
    if error_feedback:
        revision_note = {
            "role": "system",
            "content": f"Previous step had an error: {error_feedback}. Please revise.",
        }
        outgoing += [revision_note]

    stream = swarm_client.run(
        stream=True,
        context_variables=context_variables,
        messages=outgoing,
        model_override=model,
        agent=solver,
    )
    reply = process_and_print_streaming_response(stream)

    if reply and reply.context_variables:
        context_variables.update(reply.context_variables)

    print_context("After Getting Solver Step", context_variables)
    return reply
def _is_validator_message(text: str) -> bool:
    """Return True if *text* carries one of the validator's verdict markers."""
    return any(marker in text for marker in ("[✓]", "[!]", "[×]"))


def _validation_details(feedback: list, trailer: str = "</details>\n") -> list:
    """Wrap buffered validator feedback in a collapsible HTML <details> section."""
    return ["\n<details>\n<summary>🔍 Validation Feedback</summary>\n\n", *feedback, trailer]


def save_solution_to_markdown(problem: str, messages: list, timestamp: datetime, output_dir: str = "solutions"):
    """Write the tutoring transcript to a Markdown report and return its path.

    The report holds the problem statement and date, then one "### Step N"
    section per user message containing "Current step" or "Continue with".
    Solver replies appear as "Solution Step" entries. Validator replies, meaning
    any assistant message containing "[✓]", "[!]" or "[×]", are folded into
    collapsible "Validation Feedback" sections. A summary lists each solver
    message that starts with "Step". If any message contains
    "[SOLUTION COMPLETE]", a completion section follows, quoting validator
    replies found among the last three messages.

    Args:
        problem: Problem statement shown at the top of the report.
        messages: Chat transcript as dicts with "role" and "content"
            (a ``None`` content is treated as "").
        timestamp: Used for the file name and the report's date line.
        output_dir: Directory for the report; created if missing.

    Returns:
        ``f"{output_dir}/solution_YYYYmmdd_HHMMSS.md"``. A second report with
        the same timestamp (to the second) overwrites the first.
    """
    os.makedirs(output_dir, exist_ok=True)
    filename = f"{output_dir}/solution_{timestamp.strftime('%Y%m%d_%H%M%S')}.md"

    content = [
        "# Math Problem Solution\n",
        "## Problem Statement\n",
        f"```math\n{problem}\n```\n",
        f"**Date:** {timestamp.strftime('%Y-%m-%d %H:%M:%S')}\n\n",
        "## Solution Process\n",
    ]

    # Group solver steps and their validator feedback under step headings.
    step_num = 1
    current_step = []
    validation_buffer = []
    for msg in messages:
        text = msg.get("content") or ""
        if msg["role"] == "user":
            if "Current step" in text or "Continue with" in text:
                # Flush any pending validation before starting a new step.
                if validation_buffer:
                    current_step.extend(_validation_details(validation_buffer))
                    validation_buffer = []
                content.extend(current_step)
                current_step = []
                content.append(f"\n### Step {step_num}\n")
                step_num += 1
            elif "hint" in text.lower():
                current_step.append("\n#### 🤔 User requested hint:\n")
        elif msg["role"] == "assistant":
            msg_content = text.strip()
            if _is_validator_message(msg_content):
                validation_buffer.append(f"{msg_content}\n\n")
            else:
                # Validation of the previous solver message goes before this one.
                if validation_buffer:
                    current_step.extend(_validation_details(validation_buffer, "</details>\n\n"))
                    validation_buffer = []
                current_step.append("\n📝 **Solution Step:**\n")
                current_step.append(f"{msg_content}\n")

    if validation_buffer:
        current_step.extend(_validation_details(validation_buffer))
    content.extend(current_step)

    # Summary: the "Step X:" line of each solver message plus its next line.
    content.extend(["\n## Summary\n", "### Key Steps:\n"])
    for msg in messages:
        if msg["role"] != "assistant":
            continue
        msg_content = (msg.get("content") or "").strip()
        if not msg_content.startswith("Step") or _is_validator_message(msg_content):
            continue
        step_lines = msg_content.split("\n")
        step_desc = step_lines[0]
        if len(step_lines) > 1 and step_lines[1].strip():
            step_desc += f"\n - {step_lines[1].strip()}"
        content.append(f"- {step_desc}\n")

    if any("[SOLUTION COMPLETE]" in (msg.get("content") or "") for msg in messages):
        content.append("\n### ✅ Solution Complete\n")
        final_validations = [
            msg["content"]
            for msg in messages[-3:]
            if msg["role"] == "assistant" and _is_validator_message(msg.get("content") or "")
        ]
        if final_validations:
            content.append("**Final Validation:**\n")
            content.extend(f"> {validation}\n" for validation in final_validations)

    # Use explicit UTF-8: the report contains emoji and ✓/× markers that a
    # platform default encoding (e.g. cp1252 on Windows) cannot encode.
    with open(filename, "w", encoding="utf-8") as f:
        f.writelines(content)
    return filename
def _next_user_message(context_variables: Dict[str, Any], problem: str) -> Dict[str, str]:
    """Build the user turn that asks the solver for the current step."""
    if context_variables["current_step"] == 1:
        return {"role": "user", "content": f"Current step {context_variables['current_step']}: {problem}"}
    return {"role": "user", "content": "Continue with the next step"}


def _solve_and_validate_step(solver, validator, problem, messages, context_variables, max_validation_retries) -> bool:
    """Get one validated step from the solver and append it to *messages*.

    While the validator reports an error, the solver is asked to revise using
    the validator's feedback, up to ``max_validation_retries`` times. After
    that the last attempt is kept with a warning, so a stubborn disagreement
    cannot loop forever.

    Returns:
        True when the validator accepted the step and declared the solution complete.
    """
    error_feedback = None
    revisions = 0
    while True:
        context_variables["active_agent"] = "solver"
        response_obj = get_solver_step(solver, messages, context_variables, error_feedback)
        if not response_obj or not response_obj.messages:
            return False

        step_text = response_obj.messages[-1]["content"]
        context_variables["current_content"] = step_text
        validation_obj, has_error, error_message, solution_complete = process_validation(
            validator, step_text, problem, context_variables
        )

        if has_error and revisions < max_validation_retries:
            revisions += 1
            error_feedback = error_message
            continue
        if has_error:
            console.print(f"\n[bold yellow]Warning: step still failed validation after {revisions} revision(s); keeping the last attempt.[/bold yellow]")

        messages.extend(response_obj.messages)
        if validation_obj and validation_obj.messages:
            messages.extend(validation_obj.messages)
            context_variables["last_validation"] = validation_obj.messages[-1]["content"]
        return solution_complete and not has_error


def _show_hint(solver, messages, context_variables) -> None:
    """Ask the solver for a hint on the current step and stream it to the console."""
    messages.append({"role": "user", "content": "Please provide a hint for the current step"})
    console.print("\n[bold green]Hint:[/bold green]")
    hint_response = swarm_client.run(
        agent=solver,
        model_override=model,
        messages=messages,
        context_variables=context_variables,
        stream=True,
    )
    hint_obj = process_and_print_streaming_response(hint_response)
    if hint_obj and hint_obj.messages:
        messages.extend(hint_obj.messages)


def interactive_problem_solving(problem: str, max_validation_retries: int = 3):
    """Run an interactive, step-by-step tutoring session for *problem*.

    Each round sends a user turn, gets a step from the solver and has the
    validator check it, revising up to ``max_validation_retries`` times on an
    "[×]" verdict. The user is then asked how to proceed:

    * "yes"  - advance ``current_step`` and request the next step;
    * "no"   - ask the solver to re-explain the current step;
    * "hint" - stream a hint, then ask again (the step does not advance);
    * "quit" - end the session.

    When the validator reports "[SOLUTION COMPLETE]", the transcript is saved
    with ``save_solution_to_markdown`` and the function returns. Unexpected
    exceptions are reported and the user may retry the round.
    """
    solver = create_math_solver_agent()
    validator = create_validator_agent()

    context_variables = {
        "current_step": 1,
        "problem": problem,
        "active_agent": "solver",
        "last_validation": None,
        "solution_complete": False,
        "handoff_count": {},
        "handoff_history": [],
    }
    print_context("Initial State", context_variables)
    messages = []
    solution_timestamp = datetime.now()

    # The problem is rendered as plain Text so brackets in it are not parsed as markup.
    console.print(Panel.fit(
        Text.assemble(
            ("Let's solve this problem step by step:", "bold blue"),
            "\n",
            (problem, "yellow"),
        ),
        title="Math Problem Solver",
        border_style="blue",
    ))
    console.print()

    # User turn to send next. None means "request the current step"; after a
    # "no" it holds the re-explanation request instead, so the two never
    # contradict each other in the transcript.
    pending_message = None
    while True:
        try:
            messages.append(pending_message or _next_user_message(context_variables, problem))
            pending_message = None

            if _solve_and_validate_step(solver, validator, problem, messages, context_variables, max_validation_retries):
                context_variables["solution_complete"] = True
                solution_file = save_solution_to_markdown(problem, messages, solution_timestamp)
                console.print(Panel.fit(
                    "[bold green]✓ Problem solved successfully![/bold green]\n" +
                    f"[blue]Solution saved to: [/blue][yellow]{solution_file}[/yellow]",
                    border_style="green",
                ))
                return

            # Re-ask after a hint so that asking for one never advances the step.
            while True:
                user_input = Prompt.ask(
                    "\nDo you understand and want to continue?",
                    choices=["yes", "no", "hint", "quit"],
                    default="yes",
                ).lower()
                if user_input != "hint":
                    break
                _show_hint(solver, messages, context_variables)

            if user_input == "quit":
                console.print(Panel("[yellow]Ending the problem-solving session.[/yellow]", border_style="red"))
                break
            if user_input == "no":
                pending_message = {"role": "user", "content": "I don't understand. Please explain the current step again."}
                continue

            context_variables["current_step"] += 1
        except Exception as e:
            console.print("\n[bold red]Error:[/bold red]", Text(str(e), style="bold red"))
            console.print(Panel("[yellow]An error occurred. Would you like to try again?[/yellow]", border_style="red"))
            retry = Prompt.ask("Try again?", choices=["yes", "no"], default="yes").lower()
            if retry == "no":
                break
            continue
if __name__ == "__main__":
    # The problem comes from the command-line arguments when given, e.g.
    #   python math_agents.py "solve 2x + 3 = 11"
    # otherwise the built-in demo problem is used.
    problem = " ".join(sys.argv[1:]).strip() or "solve 8^k + 2^k = 130"
    try:
        interactive_problem_solving(problem)
    except KeyboardInterrupt:
        console.print("\n[yellow]Session interrupted by user.[/yellow]")
    except Exception as e:
        # Print the error text as plain Text so brackets in it cannot break markup.
        console.print("\n[red bold]Fatal error:[/red bold]", Text(str(e)), style="red")
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.