Student-Teacher-GRPO-Proof-of-Concept
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Enhanced Unsloth GRPO with Student-Teacher Reward Mechanism
------------------------------------------------------------
This script extends Unsloth's GRPO by implementing a novel student-teacher reward
mechanism for improved reasoning chains.
Key Components:
1. Student Model (Unsloth GRPO):
- Generates initial reasoning and answers.
- Incorporates teacher hints through iterative "Aha!" moments.
- Uses sampling (not argmax) for diversity.
2. Teacher Model (via LiteLLM or locally):
- Provides subtle hints (without revealing the final answer).
- Can be invoked locally or via an external API.
- Teacher hints are tokenized to ensure compatibility.
3. Reward Mechanism:
- Base reward for correct answers (REWARD_BASE = 2.0).
- Quadratic penalty for extra iterations with a coefficient (PENALTY_COEFFICIENT = 0.05).
- Minimum reward floor (0.5).
Usage:
1. Configure models (set teacher_model_id and teacher_api_url if needed).
2. Run: python unsloth_student_teacher_training.py
3. Monitor logs in "training_{time}.log"
Requirements:
- Python 3.10+
- Unsloth, LiteLLM, VLLM, Loguru, Requests, Tenacity, Datasets, PyTorch with CUDA.
Author: Graham Anderson
Date: 2025-02-17
License: Apache 2.0
"""
import os
import sys
import re
import time
import torch
from datasets import load_dataset
from loguru import logger
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from vllm import SamplingParams
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
from trl import GRPOConfig, GRPOTrainer
import requests
import asyncio
from typing import List
# --- Configuration Constants ---
REWARD_BASE = 2.0
PENALTY_COEFFICIENT = 0.05 # Lower than 0.1 to avoid over-penalizing multi-step reasoning
MIN_REWARD = 0.5
DEBUG_CONFIG = {
    "ENABLE_DEBUG": os.getenv("ENABLE_DEBUG", "false").lower() == "true",
    "DEBUG_SAMPLES": int(os.getenv("DEBUG_SAMPLES", "2")),
    "LOG_LEVEL": os.getenv("LOG_LEVEL", "INFO"),
}
# Configure Loguru logger.
logger.add(
    "training_{time}.log",
    rotation="100 MB",
    retention="10 days",
    level=DEBUG_CONFIG["LOG_LEVEL"],
    backtrace=True,
    diagnose=True,
    enqueue=True,
)
# Apply the GRPO patch.
PatchFastRL("GRPO", FastLanguageModel)
# ---------------------- Helper Functions ----------------------
def extract_xml_answer(text: str) -> str:
    """Extract the answer from XML-formatted text."""
    try:
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()
    except Exception as e:
        logger.error(f"XML parsing error: {e} | Text snippet: {text[:100]}...")
        return ""


def _calculate_reward(iterations: int, is_correct: bool) -> float:
    """
    Calculate reward based on iterations.
    Uses a quadratic penalty with a lower coefficient to not over-penalize multi-step reasoning.
    """
    base = REWARD_BASE if is_correct else 1.0
    penalty = (iterations ** 2) * PENALTY_COEFFICIENT
    return max(base - penalty, MIN_REWARD)
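

# Worked example of the reward schedule above (arithmetic only, not from the
# original gist): a correct answer with no extra teacher-hint rounds earns the
# full REWARD_BASE; each additional round subtracts a quadratic penalty, and the
# result never drops below MIN_REWARD.
#
#   _calculate_reward(0, True)  -> max(2.0 - 0.05 * 0, 0.5) = 2.0
#   _calculate_reward(2, True)  -> max(2.0 - 0.05 * 4, 0.5) = 1.8
#   _calculate_reward(3, False) -> max(1.0 - 0.05 * 9, 0.5) = 0.55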
# Async retry decorator for teacher hints using LiteLLM.
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(Exception),
    before_sleep=lambda retry_state: logger.warning(
        f"Retrying teacher hint (attempt {retry_state.attempt_number})"
    ),
)
async def get_single_teacher_hint(student_ans: str, teacher_model) -> str:
    """
    Fetch a single teacher hint asynchronously using LiteLLM's API.
    Returns a hint string starting with "Aha!".
    """
    try:
        from litellm import acompletion
        response = await acompletion(
            model=teacher_model,
            messages=[{
                "role": "user",
                "content": (
                    f"Student answer: '{student_ans}'. "
                    "Provide a subtle hint (do not reveal the final answer) "
                    "to guide improvement. Begin with 'Aha!'."
                )
            }],
            temperature=0.8,
            top_p=0.95,
            max_tokens=50
        )
        hint = response.choices[0].message.content.strip()
        if not hint.startswith("Aha!"):
            hint = f"Aha! {hint}"
        return hint
    except Exception as e:
        logger.error(f"Teacher hint generation failed: {e}")
        raise
async def get_batch_teacher_hints(student_answers: list[str], teacher_model) -> list[str]:
    """
    Fetch teacher hints for a batch of student answers concurrently.
    """
    tasks = [get_single_teacher_hint(ans, teacher_model) for ans in student_answers]
    hints = await asyncio.gather(*tasks, return_exceptions=True)
    processed = []
    for idx, hint in enumerate(hints):
        if isinstance(hint, Exception):
            logger.error(f"Hint {idx} failed: {hint}")
            processed.append("Aha! Please reconsider your approach.")
        else:
            processed.append(hint)
    logger.debug(f"Obtained {len(processed)} teacher hints.")
    return processed


class TeacherHintManager:
    def __init__(self, teacher_model):
        self.loop = asyncio.new_event_loop()
        self.teacher_model = teacher_model

    def get_hints(self, student_answers: list[str]) -> list[str]:
        return self.loop.run_until_complete(
            get_batch_teacher_hints(student_answers, self.teacher_model)
        )
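

# Illustrative usage (not in the original gist): the manager lets synchronous
# reward-function code drive the async batch helper. With an external backend,
# `teacher_model` is a LiteLLM model id string; the id below is an assumption.
#
#   hint_manager = TeacherHintManager("openai/gpt-4o-mini")
#   hints = hint_manager.get_hints(["42", "7 apples"])
#   # -> ["Aha! ...", "Aha! ..."]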
def tokenize_teacher_hint(hint: str, teacher_tokenizer) -> str:
    """
    Ensures the teacher hint is tokenized and decoded with the teacher tokenizer,
    to align its format with the student model.
    """
    try:
        tokens = teacher_tokenizer.encode(hint, add_special_tokens=False)
        decoded_hint = teacher_tokenizer.decode(tokens, skip_special_tokens=True)
        return decoded_hint.strip()
    except Exception as e:
        logger.error(f"Error tokenizing teacher hint: {e}")
        return hint
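

# `call_teacher_hint_litelm` is referenced in the "external" branch of
# iterative_teacher_hint_reward_func below but is not defined in the original
# gist. The sketch here is a minimal, assumed implementation: it POSTs the
# teacher prompt to a generic JSON completion endpoint via `requests` (imported
# above). The payload shape ({"prompt": ..., "max_tokens": ...}) and the
# response field ("completion") are assumptions; adapt them to your API.
def call_teacher_hint_litelm(prompt: str, api_url: str, api_key: str | None = None) -> str:
    """Fetch a teacher hint from an external completion API (assumed JSON schema)."""
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    try:
        resp = requests.post(
            api_url,
            json={"prompt": prompt, "max_tokens": 50, "temperature": 0.8},
            headers=headers,
            timeout=30,
        )
        resp.raise_for_status()
        hint = str(resp.json().get("completion", "")).strip()
    except Exception as e:
        logger.error(f"External teacher hint request failed: {e}")
        return "Aha! Please reconsider your approach."
    if not hint:
        return "Aha! Please reconsider your approach."
    if not hint.startswith("Aha!"):
        hint = f"Aha! {hint}"
    return hint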
# ---------------------- Reward Functions ----------------------
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    return [2.0 if s.strip() == a.strip() else 0.0 for s, a in zip(extracted, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    return [0.5 if s.strip().isdigit() else 0.0 for s in extracted]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [comp[0]["content"] for comp in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if m else 0.0 for m in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [comp[0]["content"] for comp in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if m else 0.0 for m in matches]


def count_xml(text: str) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [comp[0]["content"] for comp in completions]
    return [count_xml(c) for c in contents]
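

# Worked example (not in the original gist) for a completion that exactly matches
# the expected XML layout, with single-line reasoning content:
#
#   text = "<reasoning>\nthree r's\n</reasoning>\n<answer>\n3\n</answer>\n"
#
#   strict_format_reward_func -> 0.5  (the full-line pattern matches end to end)
#   count_xml(text)           -> 0.5  (0.125 per correctly placed tag, minus a
#                                      small penalty for any text after </answer>)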
def iterative_teacher_hint_reward_func(
    prompts,
    completions,
    answer,
    teacher_model,
    teacher_tokenizer,
    student_model,
    max_iterations=3,
    **kwargs
) -> list[float]:
    """
    Iteratively evaluates the student's answer and, if incorrect, obtains a teacher hint.
    The teacher backend can be "local" or "external" (using LiteLLM).
    Teacher hints are tokenized to ensure compatibility.
    The teacher hint is appended to the student's chain-of-thought, and the student model re-generates output.
    A quadratic penalty (with a milder coefficient) is applied per extra iteration.
    """
    rewards = []
    teacher_backend = kwargs.get("teacher_backend", "local")
    teacher_api_url = kwargs.get("teacher_api_url", None)
    teacher_api_key = kwargs.get("teacher_api_key", None)
    # Collect initial student answers.
    student_answers = [extract_xml_answer(comp[0]['content']) for comp in completions]
    hint_manager = TeacherHintManager(teacher_model)
    try:
        teacher_hints = hint_manager.get_hints(student_answers)
    except Exception as e:
        logger.error(f"Failed to get initial teacher hints: {e}")
        teacher_hints = ["Aha! Please reconsider your approach."] * len(student_answers)
    for idx, (comp, gt, init_hint) in enumerate(zip(completions, answer, teacher_hints)):
        try:
            student_output = comp[0]['content']
            iterations = 0
            reasoning_chain: List[str] = []
            while iterations < max_iterations:
                student_ans = extract_xml_answer(student_output)
                if student_ans.strip() == gt.strip():
                    logger.info(f"Sample {idx}: Correct answer achieved after {iterations} iterations.")
                    break
                # Prepare teacher prompt.
                teacher_prompt = (
                    f"Student answer: '{student_ans}'. "
                    "Provide a subtle hint (do not reveal the final answer) to guide improvement. "
                    "Begin your hint with 'Aha!'."
                )
                if teacher_backend == "local":
                    try:
                        hint_raw = teacher_model.fast_generate(
                            [teacher_prompt],
                            sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50),
                            lora_request=None
                        )[0].outputs[0].text.strip()
                    except Exception as e:
                        logger.error(f"Local teacher hint generation error (sample {idx}): {e}")
                        hint_raw = "Aha! Please reconsider your approach."
                elif teacher_backend == "external" and teacher_api_url is not None:
                    hint_raw = call_teacher_hint_litelm(teacher_prompt, teacher_api_url, teacher_api_key)
                else:
                    hint_raw = "Aha! Please reconsider your approach."
                # Tokenize the teacher hint to ensure compatibility.
                teacher_hint = tokenize_teacher_hint(hint_raw, teacher_tokenizer)
                reasoning_chain.append(teacher_hint)
                logger.debug(f"Sample {idx} - Iteration {iterations+1} teacher hint: {teacher_hint}")
                # Append the teacher's hint to the student output.
                student_output = student_output + "\n" + teacher_hint
                # Re-generate student output using sampling (not argmax) for diversity.
                try:
                    new_output = student_model.fast_generate(
                        [student_output],
                        sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=200),
                        lora_request=None
                    )[0].outputs[0].text
                except Exception as e:
                    logger.error(f"Student re-generation error (sample {idx}, iteration {iterations+1}): {e}")
                    new_output = student_output
                student_output = new_output
                iterations += 1
            final_ans = extract_xml_answer(student_output)
            is_correct = final_ans.strip() == gt.strip()
            reward = _calculate_reward(iterations, is_correct)
            logger.info(f"Sample {idx}: Final answer '{final_ans}' | Reward: {reward:.2f} after {iterations} iterations")
            logger.debug("Reasoning chain:\n" + "\n".join(reasoning_chain))
            rewards.append(reward)
        except Exception as e:
            logger.exception(f"Error processing sample {idx}: {e}")
            rewards.append(MIN_REWARD)
    torch.cuda.empty_cache()
    return rewards
# ---------------------- Data Preparation ----------------------
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
def get_gsm8k_questions(split="train") -> "Dataset":
    """Load and format the GSM8K dataset."""
    try:
        data = load_dataset('openai/gsm8k', 'main')[split]
        data = data.map(lambda x: {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['question']}
            ],
            'answer': x['answer'].split("####")[1].strip()
        })
        logger.info(f"Loaded {len(data)} examples from the GSM8K '{split}' split.")
        return data
    except Exception as e:
        logger.error(f"Failed to load GSM8K dataset: {e}")
        raise
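

# Note (added for clarity): GSM8K reference answers embed the final numeric
# result after a "####" marker, so an answer ending in "#### 18" yields the bare
# string "18" above, which the reward functions compare against exactly.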
dataset = get_gsm8k_questions()
# ---------------------- GRPO Training Configuration ----------------------
training_args = GRPOConfig(
    use_vllm=True,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    num_generations=8,
    max_prompt_length=256,
    max_completion_length=200,
    max_steps=250,
    save_steps=250,
    max_grad_norm=0.1,
    report_to="none",
    output_dir="outputs",
)
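
# ---------------------- Model Setup (sketch) ----------------------
# The original gist references `student_model`, `tokenizer`, `teacher_model`, and
# `teacher_tokenizer` without loading them. The block below is a minimal sketch of
# how they could be set up with Unsloth's GRPO recipe. The model ids, LoRA rank,
# sequence length, and GPU memory fraction are assumptions, not values from the
# original; adjust them to your hardware and to the teacher backend you use.
from transformers import AutoTokenizer

MAX_SEQ_LENGTH = 512     # assumed; prompt (256) + completion (200) plus headroom
LORA_RANK = 16           # assumed LoRA rank

student_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B-Instruct",   # assumed student checkpoint
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    fast_inference=True,                      # enables vLLM-backed fast_generate
    max_lora_rank=LORA_RANK,
    gpu_memory_utilization=0.6,
)
student_model = FastLanguageModel.get_peft_model(
    student_model,
    r=LORA_RANK,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=LORA_RANK,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# With teacher_backend="external", LiteLLM expects a model id string; with
# teacher_backend="local", replace this with an Unsloth model that exposes
# `fast_generate` (e.g. a second FastLanguageModel.from_pretrained call).
teacher_model = "openai/gpt-4o-mini"          # assumed LiteLLM model id
# Tokenizer used only to normalize teacher hints; the id is an assumption.
teacher_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
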
# ---------------------- Main Function ----------------------
def main():
    try:
        logger.info("Starting GRPO training with iterative teacher hints...")
        # Load dataset and optionally restrict samples in debug mode.
        dataset_full = get_gsm8k_questions()
        if DEBUG_CONFIG["ENABLE_DEBUG"]:
            dataset_full = dataset_full.select(range(DEBUG_CONFIG["DEBUG_SAMPLES"]))
            logger.warning(f"Debug mode enabled: using {DEBUG_CONFIG['DEBUG_SAMPLES']} samples.")

        # TRL's GRPOTrainer has no `reward_kwargs` argument; it forwards extra dataset
        # columns (here: `answer`) to each reward function as keyword arguments. The
        # teacher-specific settings are therefore bound in a small wrapper instead.
        def teacher_hint_reward_func(prompts, completions, answer, **kwargs):
            return iterative_teacher_hint_reward_func(
                prompts,
                completions,
                answer,
                teacher_model=teacher_model,
                teacher_tokenizer=teacher_tokenizer,
                student_model=student_model,
                max_iterations=3,
                teacher_backend="external",  # Set to "local" to use the local teacher model.
                teacher_api_url="https://your-teacher-api-endpoint.com/complete",  # Replace with your endpoint.
                teacher_api_key="your-api-key-if-required",
            )

        # Initialize GRPOTrainer with our iterative teacher hint reward function.
        trainer = GRPOTrainer(
            model=student_model,
            processing_class=tokenizer,
            reward_funcs=[
                xmlcount_reward_func,
                soft_format_reward_func,
                strict_format_reward_func,
                int_reward_func,
                correctness_reward_func,
                teacher_hint_reward_func,  # Our new iterative reward function.
            ],
            args=training_args,
            train_dataset=dataset_full,
            eval_dataset=dataset_full,
        )
logger.info("Starting training...")
train_results = trainer.train()
logger.success("Training completed successfully!")
logger.info(f"Training metrics: {train_results}")
# Save the fine-tuned model adapter.
adapter_path = f"adapters/grpo_teacher_{int(time.time())}"
trainer.save_model(adapter_path)
logger.info(f"Adapter saved to {adapter_path}")
# (Optional) Push to HuggingFace Hub.
if training_args.push_to_hub:
trainer.push_to_hub()
logger.success("Model pushed to HuggingFace Hub successfully!")
# Inference: Test on a sample prompt.
sample_prompt = tokenizer.apply_chat_template([
{"role": "user", "content": "How many r's are in strawberry?"},
], tokenize=False, add_generation_prompt=True)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1024,
)
base_output = student_model.fast_generate(
[sample_prompt],
sampling_params=sampling_params,
lora_request=None,
)[0].outputs[0].text
logger.info(f"Base Model Output:\n{base_output}")
# Save LoRA adapters.
student_model.save_lora("grpo_saved_lora")
logger.info("LoRA adapters saved.")
# Test fine-tuned model.
sample_prompt_ft = tokenizer.apply_chat_template([
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": "How many r's are in strawberry?"},
], tokenize=False, add_generation_prompt=True)
finetuned_output = student_model.fast_generate(
[sample_prompt_ft],
sampling_params=sampling_params,
lora_request=student_model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
logger.info(f"Fine-tuned Model Output:\n{finetuned_output}")
except Exception as e:
logger.exception(f"Training failed: {e}")
sys.exit(1)
if __name__ == "__main__":
main()