Skip to content

Instantly share code, notes, and snippets.

@cnndabbler
Created December 16, 2024 00:37
Show Gist options
  • Save cnndabbler/6dcb488f4c21d0d73730270c62e11a47 to your computer and use it in GitHub Desktop.
Save cnndabbler/6dcb488f4c21d0d73730270c62e11a47 to your computer and use it in GitHub Desktop.
Code implementing an agentic collaboration between a student and two AI agents, the 'Math Solver Agent' and the 'Math Validator Agent', which work together to solve math problems step by step.
from openai import OpenAI
from rich import print
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt
from rich.markdown import Markdown
from rich.text import Text
from rich.style import Style
from swarm import Swarm, Agent
from swarm.types import Response
import sys
from datetime import datetime
import os
from dataclasses import dataclass
from typing import Optional, Dict, Any
# OpenAI-compatible client pointed at a local Ollama server. The endpoint, key
# and model can be overridden through environment variables; the defaults are
# the values this script was written against.
client = OpenAI(
    base_url=os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434/v1"),
    api_key=os.environ.get("OLLAMA_API_KEY", "ollama"),
)
swarm_client = Swarm(client=client)

# Shared Rich console used for all terminal output.
console = Console()

# Model used by both agents (override with MATH_AGENT_MODEL).
# Other models tried: "qwq", "qwen2.5-coder:32b",
# "vanilj/Phi-4" (does not support function calls).
model = os.environ.get("MATH_AGENT_MODEL", "qwen2.5:32b")
def print_context(prefix: str, context_variables: Dict[str, Any]) -> None:
    """Print the current state of the context variables inside a Rich panel.

    Each item is shown on its own line as ``key: value``, with the key in
    yellow. Keys and values are wrapped in ``Text`` objects instead of being
    interpolated into a markup string. Model-generated values (e.g. a step
    containing ``[/b]`` or an interval ``[a, b]``) are therefore shown
    verbatim rather than parsed as Rich markup, which could drop text or
    raise a markup error.

    Args:
        prefix: Label shown in the heading, e.g. "Before Validation".
        context_variables: Mapping whose items are listed.
    """
    console.print(f"\n[bold cyan]Context Variables ({prefix}):[/bold cyan]")
    lines = [
        Text.assemble((str(key), "yellow"), ": ", str(value))
        for key, value in context_variables.items()
    ]
    console.print(Panel(
        Text("\n").join(lines),
        title="Current State",
        border_style="cyan",
    ))
def request_validator_handoff(context_variables: Dict[str, Any]) -> Dict[str, Any]:
    """Build a request asking the validator agent to review the current step.

    Registered as a tool on the solver agent. Missing context entries fall
    back to step ``0`` and empty content.

    Args:
        context_variables: Shared context; ``current_step`` and
            ``current_content`` are read if present.

    Returns:
        A dict with ``handoff``, ``reason``, ``current_step`` and
        ``current_content`` keys, in that order.
    """
    step = context_variables.get("current_step", 0)
    step_text = context_variables.get("current_content", "")
    return dict(
        handoff="validator",
        reason="Step validation needed",
        current_step=step,
        current_content=step_text,
    )
def request_solver_handoff(context_variables: Dict[str, Any]) -> Dict[str, Any]:
    """Build a request handing control back to the solver agent.

    Registered as a tool on the validator agent. The reason is taken from
    ``handoff_reason`` in the context, defaulting to "Major revision needed".

    Args:
        context_variables: Shared context; ``handoff_reason`` and
            ``current_step`` are read if present.

    Returns:
        A dict with ``handoff``, ``reason`` and ``current_step`` keys.
    """
    why = context_variables.get("handoff_reason", "Major revision needed")
    step = context_variables.get("current_step", 0)
    return dict(handoff="solver", reason=why, current_step=step)
def create_math_solver_agent() -> Agent:
    """Create the agent that guides the user through a problem step by step.

    The agent uses the module-level ``model``. It may call
    ``request_validator_handoff`` as a tool, and its instructions tell it to
    end a finished solution with the literal marker "[SOLUTION COMPLETE]".
    """
    solver_instructions = """You are a step-by-step math problem solver. Your role is to guide users through solving math problems one small step at a time.
Key Guidelines:
1. Keep each step SMALL and FOCUSED - only explain ONE concept at a time
2. WAIT for user confirmation after EACH step
3. Never proceed to the next step until the user confirms understanding
4. Keep explanations brief and clear
5. If user asks for a hint, provide a small hint focused on the current step only
6. After completing each step, request validation
IMPORTANT: Only mark a solution as complete when ALL of these are true:
1. All steps of the problem have been solved
2. The final answer has been found and clearly stated
3. All work has been shown and validated
When the solution is complete:
1. Clearly state the final answer
2. Summarize all steps taken
3. End your response with "[SOLUTION COMPLETE]"
Format your responses as:
Step X: [One sentence description]
[Brief explanation - max 2-3 sentences]
[Single focused question about current step]"""
    solver_tools = [request_validator_handoff]
    return Agent(
        name="MathSolverAgent",
        instructions=solver_instructions,
        functions=solver_tools,
        model=model,
    )
def create_validator_agent() -> Agent:
    """Create the agent that checks each solver step.

    The agent uses the module-level ``model`` and may call
    ``request_solver_handoff`` as a tool. Its instructions tell it to reply
    in exactly one of the ``[✓] Correct``, ``[!] Note`` or ``[×] Error``
    formats, and to append "[SOLUTION COMPLETE]" only once the whole
    solution meets its completion checklist.
    """
    validator_instructions = """You are a mathematical validation expert who reviews solution steps and provides valuable insights.
Your role is to:
1. Verify the mathematical correctness of each step
2. Point out any potential alternative approaches
3. Add helpful insights or connections to other mathematical concepts
4. If you spot an error, clearly explain what's wrong
5. Keep your responses brief and focused
Format your responses as EXACTLY ONE of these formats:
[✓] Correct: [Brief confirmation of the step]
or
[!] Note: [Brief insight or alternative approach]
or
[×] Error: [Brief explanation of the error]
IMPORTANT: A solution is NEVER complete until ALL of these conditions are met:
1. The problem has been fully solved with ALL necessary steps shown
2. A clear, unambiguous final answer has been stated and verified
3. ALL mathematical work has been validated step by step
4. NO steps are missing or unclear
5. The solution has been tested with numerical verification
6. All algebraic manipulations are correct and complete
7. The final answer satisfies the original equation when substituted back
COMPLETION CHECKLIST:
Before marking a solution complete, verify:
- Have we solved for the unknown variable completely?
- Is the final answer explicitly stated?
- Have we verified the answer by substituting back?
- Are ALL steps shown and explained?
- Have we performed numerical verification?
Only when ALL of these conditions are met, append "[SOLUTION COMPLETE]" to your response.
NEVER mark intermediate steps as complete, even if they are correct.
If ANY step is missing or unclear, provide feedback to the solver."""
    validator_tools = [request_solver_handoff]
    return Agent(
        name="ValidatorAgent",
        instructions=validator_instructions,
        functions=validator_tools,
        model=model,
    )
def handle_agent_handoff(
    source_agent: str,
    target_agent: str,
    context_variables: Dict[str, Any],
    loop_threshold: int = 3,
) -> Dict[str, Any]:
    """Record a handoff between two agents in the shared context.

    Mutates ``context_variables`` in place:

    * increments ``handoff_count[str(current_step)]["<source>_to_<target>"]``;
    * sets ``previous_agent``, ``current_agent`` and ``last_handoff_time``;
    * appends an entry to ``handoff_history``.

    The context is printed before and after the update. A warning is printed
    when the same handoff happens more than ``loop_threshold`` times within
    one step, since that suggests the agents are looping.

    Args:
        source_agent: Name of the agent giving up control.
        target_agent: Name of the agent receiving control.
        context_variables: Shared context dict (mutated in place).
        loop_threshold: Count above which a potential loop is reported.

    Returns:
        The same ``context_variables`` object, after the update.
    """
    current_step = context_variables.get("current_step", 1)
    handoff_count = context_variables.get("handoff_count", {})
    print_context(f"Before Handoff ({source_agent} -> {target_agent})", context_variables)

    # Counters are keyed by the step (as a string), then by handoff direction.
    handoff_key = f"{source_agent}_to_{target_agent}"
    step_counts = handoff_count.setdefault(str(current_step), {})
    step_counts[handoff_key] = step_counts.get(handoff_key, 0) + 1
    current_count = step_counts[handoff_key]

    if current_count > loop_threshold:
        console.print(f"\n[bold red]Warning: Detected potential loop - {handoff_key} has occurred {current_count} times in step {current_step}[/bold red]")
    console.print(f"\n[dim]Debug: Handoff from {source_agent} to {target_agent} (Step {current_step}, Count: {current_count})[/dim]")

    # One timestamp so last_handoff_time matches the history entry exactly.
    timestamp = datetime.now().isoformat()
    context_variables.update({
        "previous_agent": source_agent,
        "current_agent": target_agent,
        "handoff_count": handoff_count,
        "last_handoff_time": timestamp,
    })

    handoff_history = context_variables.get("handoff_history", [])
    handoff_history.append({
        "step": current_step,
        "from": source_agent,
        "to": target_agent,
        "count": current_count,
        "timestamp": timestamp,
    })
    context_variables["handoff_history"] = handoff_history

    print_context(f"After Handoff ({source_agent} -> {target_agent})", context_variables)
    return context_variables
def process_and_print_streaming_response(response):
    """Print a streamed Swarm response as it arrives and collect its text.

    Plain-string chunks are skipped. When the ``sender`` changes, a
    "<sender>:" header is printed. ``content`` deltas are printed and
    accumulated. A ``{"delim": "end"}`` chunk finishes the current line.
    A ``response`` chunk is printed unless its text starts with "messages=".

    Model output is printed with ``markup=False`` so that text such as
    ``[/b]`` or ``[a, b]`` appears verbatim instead of being interpreted as
    Rich markup, which could drop text or raise a markup error.

    Args:
        response: Iterable of stream chunks, as returned by
            ``swarm_client.run(..., stream=True)``.

    Returns:
        A ``Response`` with a single assistant message containing the
        concatenated text, plus any ``context_variables`` seen in the
        chunks, or ``None`` if building the ``Response`` fails.
    """
    full_response = ""
    content = ""  # text printed on the current output line
    last_sender = ""
    context_variables = {}

    for chunk in response:
        if isinstance(chunk, str):
            continue

        if "sender" in chunk:
            sender = chunk["sender"]
            if sender and sender != last_sender:
                if content:
                    console.print()
                console.print(Text.assemble("\n", (f"{sender}:", "blue"), " "), end="")
                last_sender = sender

        if "content" in chunk and chunk["content"] is not None:
            delta = chunk["content"]
            console.print(delta, end="", markup=False)
            sys.stdout.flush()
            content += delta
            full_response += delta

        if "delim" in chunk and chunk["delim"] == "end":
            if content:
                console.print()
                content = ""

        if "response" in chunk:
            text = str(chunk["response"])
            # The final Response object stringifies as "messages=..."; skip it.
            if text.startswith("messages="):
                continue
            full_response += text
            console.print(text, style="blue", end="", markup=False)
            sys.stdout.flush()

        if "context_variables" in chunk:
            context_variables.update(chunk["context_variables"])

    if content:
        console.print()

    try:
        return Response(
            messages=[{"role": "assistant", "content": full_response}],
            context_variables=context_variables,
        )
    except Exception as e:
        console.print(f"[red]Error processing response: {e}[/red]")
        return None
def process_validation(validator: Agent, step_content: str, problem: str, context_variables: Dict[str, Any]) -> tuple:
    """Ask the validator agent to check one solver step.

    Records a solver -> validator handoff and streams the validator's reply.
    Control is then handed back to the solver (validator -> solver) whatever
    the outcome.

    Args:
        validator: The validator agent.
        step_content: The solver's step text to check.
        problem: Original problem statement, included in the request.
        context_variables: Shared context dict (mutated in place).

    Returns:
        ``(response_obj, has_error, error_message, solution_complete)``:

        * ``response_obj`` is the validator's ``Response``, or ``None``.
        * ``has_error`` is True when the reply (minus any "ValidatorAgent:"
          prefix) starts with "[×]".
        * ``error_message`` is that reply when ``has_error`` is True,
          otherwise "".
        * ``solution_complete`` is True when the reply contains
          "[SOLUTION COMPLETE]".
    """
    print_context("Before Validation", context_variables)
    validation_context = handle_agent_handoff("solver", "validator", context_variables)

    validation_msg = {
        "role": "user",
        "content": f"Validate this solution step for the problem '{problem}':\n\n{step_content}"
    }
    console.print("\n[dim]Debug: Starting validation process...[/dim]")
    response = swarm_client.run(
        agent=validator,
        model_override=model,
        messages=[validation_msg],
        context_variables=validation_context,
        stream=True,
    )
    response_obj = process_and_print_streaming_response(response)

    has_error = False
    error_message = ""
    solution_complete = False
    if response_obj and response_obj.messages:
        validation_content = response_obj.messages[0]["content"].replace("ValidatorAgent:", "").strip()
        has_error = validation_content.startswith("[×]")
        solution_complete = "[SOLUTION COMPLETE]" in validation_content

    console.print(f"\n[dim]Debug: Validation result - Error: {has_error}, Solution Complete: {solution_complete}[/dim]")
    if has_error:
        error_message = validation_content
        validation_context["handoff_reason"] = error_message
        console.print("\n[dim]Debug: Validation failed, returning to solver[/dim]")
    else:
        # Drop a reason left over from an earlier failed validation so that
        # request_solver_handoff does not report a stale error later.
        validation_context.pop("handoff_reason", None)
        if solution_complete:
            console.print("\n[dim]Debug: Solution complete, proceeding to final verification[/dim]")
        else:
            console.print("\n[dim]Debug: Validation passed, returning to solver for next step[/dim]")
    # Every outcome hands control back to the solver.
    context_variables.update(handle_agent_handoff("validator", "solver", validation_context))
    print_context("After Validation", context_variables)

    if response_obj and response_obj.context_variables:
        context_variables.update(response_obj.context_variables)
    return response_obj, has_error, error_message, solution_complete
def get_solver_step(
    solver: Agent,
    messages: list,
    context_variables: Dict[str, Any],
    error_feedback: Optional[str] = None,
) -> Optional[Response]:
    """Stream the solver's next step, optionally asking it to revise.

    ``messages`` itself is not modified. When ``error_feedback`` is given, a
    system message asking for a revision is appended to a copy of the
    conversation before it is sent.

    Args:
        solver: The solver agent.
        messages: Conversation so far.
        context_variables: Shared context dict. It is updated with any
            context variables returned in the response.
        error_feedback: Validator error text from the previous attempt, if any.

    Returns:
        The solver's ``Response``, or ``None`` if it could not be built.
    """
    print_context("Before Getting Solver Step", context_variables)

    current_messages = list(messages)
    if error_feedback:
        current_messages.append({
            "role": "system",
            "content": f"Previous step had an error: {error_feedback}. Please revise."
        })

    response = swarm_client.run(
        agent=solver,
        model_override=model,
        messages=current_messages,
        context_variables=context_variables,
        stream=True,
    )
    response_obj = process_and_print_streaming_response(response)

    if response_obj and response_obj.context_variables:
        context_variables.update(response_obj.context_variables)
    print_context("After Getting Solver Step", context_variables)
    return response_obj
# Prefixes the validator is instructed to use for its verdicts.
_VALIDATION_MARKERS = ("[✓]", "[!]", "[×]")


def _is_validation_message(text: str) -> bool:
    """Return True if *text* contains one of the validator verdict markers."""
    return any(marker in text for marker in _VALIDATION_MARKERS)


def _validation_details(validations: list, closing: str) -> list:
    """Wrap buffered validator replies in a collapsible <details> section."""
    return ["\n<details>\n<summary>🔍 Validation Feedback</summary>\n\n", *validations, closing]


def save_solution_to_markdown(problem: str, messages: list, timestamp: datetime):
    """Write the solving session to a Markdown file under ``solutions/``.

    User messages containing "Current step" or "Continue with" start a new
    "### Step N" section. User messages mentioning "hint" add a hint marker.
    Assistant messages containing a validator marker (``[✓]``, ``[!]``,
    ``[×]``) are grouped into collapsible "Validation Feedback" sections;
    other assistant messages are written as solution steps. A summary of
    the solver's "Step ..." lines follows. If any message contains
    "[SOLUTION COMPLETE]", a completion note is added, together with any
    validator replies among the last three messages.

    The file is written as UTF-8 because the content contains emoji and
    ✓/× characters, which the platform default encoding (e.g. cp1252 on
    Windows) cannot always represent.

    Args:
        problem: The original problem statement.
        messages: Conversation as ``{"role": ..., "content": ...}`` dicts.
        timestamp: Session start time, used in the file name and header.

    Returns:
        Path of the written file, ``solutions/solution_YYYYMMDD_HHMMSS.md``.
    """
    os.makedirs("solutions", exist_ok=True)
    filename = f"solutions/solution_{timestamp.strftime('%Y%m%d_%H%M%S')}.md"

    content = [
        "# Math Problem Solution\n",
        "## Problem Statement\n",
        f"```math\n{problem}\n```\n",
        f"**Date:** {timestamp.strftime('%Y-%m-%d %H:%M:%S')}\n\n",
        "## Solution Process\n",
    ]

    step_num = 1
    current_step = []        # pieces of the step section being built
    validation_buffer = []   # validator replies not yet written out
    for msg in messages:
        if msg["role"] == "user":
            text = msg["content"]
            if "Current step" in text or "Continue with" in text:
                # Flush pending validation and the previous step, then open a new step.
                if validation_buffer:
                    current_step.extend(_validation_details(validation_buffer, "</details>\n"))
                    validation_buffer = []
                if current_step:
                    content.extend(current_step)
                    current_step = []
                content.append(f"\n### Step {step_num}\n")
                step_num += 1
            elif "hint" in text.lower():
                current_step.append("\n#### 🤔 User requested hint:\n")
        elif msg["role"] == "assistant":
            msg_content = msg["content"].strip()
            if _is_validation_message(msg_content):
                validation_buffer.append(f"{msg_content}\n\n")
            else:
                # Pending validation belongs before the next solver message.
                if validation_buffer:
                    current_step.extend(_validation_details(validation_buffer, "</details>\n\n"))
                    validation_buffer = []
                current_step.append("\n📝 **Solution Step:**\n")
                current_step.append(f"{msg_content}\n")

    if validation_buffer:
        current_step.extend(_validation_details(validation_buffer, "</details>\n"))
    if current_step:
        content.extend(current_step)

    content.extend(["\n## Summary\n", "### Key Steps:\n"])
    for msg in messages:
        if msg["role"] != "assistant":
            continue
        msg_content = msg["content"].strip()
        if msg_content.startswith("Step") and not _is_validation_message(msg_content):
            # The "Step X:" line plus the explanation line right after it, if any.
            step_lines = msg_content.split("\n")
            step_desc = step_lines[0]
            if len(step_lines) > 1 and step_lines[1].strip():
                step_desc += f"\n - {step_lines[1].strip()}"
            content.append(f"- {step_desc}\n")

    if any("[SOLUTION COMPLETE]" in msg["content"] for msg in messages):
        content.append("\n### ✅ Solution Complete\n")
        final_validations = [
            msg["content"] for msg in messages[-3:]
            if msg["role"] == "assistant" and _is_validation_message(msg["content"])
        ]
        if final_validations:
            content.append("**Final Validation:**\n")
            content.extend(f"> {validation}\n" for validation in final_validations)

    with open(filename, "w", encoding="utf-8") as f:
        f.write("".join(content))
    return filename
def _solve_and_validate_step(solver, validator, problem, messages, context_variables, max_revisions) -> bool:
    """Get one step from the solver, validate it and record the accepted step.

    A step the validator flags with "[×]" is sent back to the solver with the
    error text, up to ``max_revisions`` times. After that, the latest attempt
    is accepted with a warning so the session cannot loop forever. The
    accepted step and the validator's reply are appended to ``messages``.

    Returns:
        True when the validator marked the solution complete.
    """
    error_feedback = None
    revisions = 0
    while True:
        context_variables["active_agent"] = "solver"
        response_obj = get_solver_step(solver, messages, context_variables, error_feedback)
        if not response_obj or not response_obj.messages:
            return False

        step_content = response_obj.messages[-1]["content"]
        context_variables["current_content"] = step_content
        validation_obj, has_error, error_message, solution_complete = process_validation(
            validator,
            step_content,
            problem,
            context_variables,
        )

        if has_error and revisions < max_revisions:
            revisions += 1
            error_feedback = error_message
            continue
        if has_error:
            console.print(
                f"\n[bold yellow]Warning: the validator still reports an error after "
                f"{max_revisions} revision(s); continuing with the latest attempt.[/bold yellow]"
            )

        messages.extend(response_obj.messages)
        if validation_obj:
            messages.extend(validation_obj.messages)
        return solution_complete


def _ask_user_next_action(solver, messages, context_variables) -> str:
    """Ask the user how to proceed, serving hints until another choice is made.

    Each "hint" answer streams a hint from the solver (recorded in
    ``messages``) and then asks again, so a hint never advances the step.

    Returns:
        "yes", "no" or "quit".
    """
    while True:
        user_input = Prompt.ask(
            "\nDo you understand and want to continue?",
            choices=["yes", "no", "hint", "quit"],
            default="yes"
        ).lower()
        if user_input != "hint":
            return user_input

        messages.append({"role": "user", "content": "Please provide a hint for the current step"})
        console.print("\n[bold green]Hint:[/bold green]")
        hint_response = swarm_client.run(
            agent=solver,
            model_override=model,
            messages=messages,
            context_variables=context_variables,
            stream=True,
        )
        hint_obj = process_and_print_streaming_response(hint_response)
        if hint_obj and hint_obj.messages:
            messages.extend(hint_obj.messages)


def interactive_problem_solving(problem: str, max_revisions: int = 5):
    """Run an interactive, step-by-step solving session for *problem*.

    Each round works as follows:

    1. The solver agent produces a step.
    2. The validator agent checks it; a rejected step goes back to the
       solver, up to ``max_revisions`` times.
    3. The user chooses to continue, see the step explained again, get a
       hint, or quit.

    When validation reports "[SOLUTION COMPLETE]", the transcript is saved
    with ``save_solution_to_markdown`` and the function returns.

    Args:
        problem: The problem statement, sent to the solver first.
        max_revisions: Maximum number of times a step rejected by the
            validator is sent back for revision before it is accepted anyway.
    """
    solver = create_math_solver_agent()
    validator = create_validator_agent()

    context_variables = {
        "current_step": 1,
        "problem": problem,
        "active_agent": "solver",
        "last_validation": None,
        "solution_complete": False,
        "handoff_count": {},
        "handoff_history": []
    }
    print_context("Initial State", context_variables)

    messages = []
    solution_timestamp = datetime.now()

    # The problem is wrapped in Text so brackets in it are not read as markup.
    console.print(Panel.fit(
        Text.assemble(
            ("Let's solve this problem step by step:", "bold blue"),
            "\n",
            (problem, "yellow"),
        ),
        title="Math Problem Solver",
        border_style="blue"
    ))
    console.print()

    # A follow-up request (e.g. "explain again") to send instead of the
    # default prompt on the next round.
    pending_message = None
    while True:
        try:
            if pending_message is not None:
                current_message, pending_message = pending_message, None
            elif context_variables['current_step'] == 1:
                current_message = {
                    "role": "user",
                    "content": f"Current step {context_variables['current_step']}: {problem}"
                }
            else:
                current_message = {
                    "role": "user",
                    "content": "Continue with the next step"
                }
            messages.append(current_message)

            if _solve_and_validate_step(solver, validator, problem, messages, context_variables, max_revisions):
                solution_file = save_solution_to_markdown(problem, messages, solution_timestamp)
                console.print(Panel.fit(
                    "[bold green]✓ Problem solved successfully![/bold green]\n" +
                    f"[blue]Solution saved to: [/blue][yellow]{solution_file}[/yellow]",
                    border_style="green"
                ))
                return

            user_input = _ask_user_next_action(solver, messages, context_variables)
            if user_input == 'quit':
                console.print(Panel("[yellow]Ending the problem-solving session.[/yellow]", border_style="red"))
                break
            if user_input == 'no':
                # Ask for the same step again instead of moving on.
                pending_message = {"role": "user", "content": "I don't understand. Please explain the current step again."}
                continue

            context_variables['current_step'] += 1
        except Exception as e:
            console.print(Text(f"\nError: {e}", style="bold red"))
            console.print(Panel("[yellow]An error occurred. Would you like to try again?[/yellow]", border_style="red"))
            retry = Prompt.ask("Try again?", choices=["yes", "no"], default="yes").lower()
            if retry == "no":
                break
            continue
if __name__ == "__main__":
    # The problem may be given on the command line; otherwise use the demo problem.
    problem = " ".join(sys.argv[1:]).strip() or "solve 8^k + 2^k = 130"
    try:
        interactive_problem_solving(problem)
    except KeyboardInterrupt:
        console.print("\n[yellow]Session interrupted by user.[/yellow]")
    except Exception as e:
        # Text keeps brackets in the exception message from being parsed as markup.
        console.print(Text.assemble("\n", ("Fatal error:", "red bold"), f" {e}"), style="red")
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment