Student-Teacher-GRPO-Proof-of-Concept
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Enhanced Unsloth GRPO with Student-Teacher Reward Mechanism
------------------------------------------------------------
This script extends Unsloth's GRPO by implementing a novel student-teacher reward
mechanism for improved reasoning chains.

Key Components:
  1. Student Model (Unsloth GRPO):
     - Generates initial reasoning and answers.
     - Incorporates teacher hints through iterative "Aha!" moments.
     - Uses sampling (not argmax) for diversity.
  2. Teacher Model (via LiteLLM or locally):
     - Provides subtle hints (without revealing the final answer).
     - Can be invoked locally or via an external API.
     - Teacher hints are tokenized to ensure compatibility.
  3. Reward Mechanism:
     - Base reward for correct answers (REWARD_BASE = 2.0).
     - Quadratic penalty for extra iterations (PENALTY_COEFFICIENT = 0.05).
     - Minimum reward floor (MIN_REWARD = 0.5).

Usage:
  1. Configure models (set teacher_model_id and teacher_api_url if needed).
  2. Run: python unsloth_student_teacher_training.py
  3. Monitor logs in "training_{time}.log"

Requirements:
  - Python 3.10+
  - Unsloth, LiteLLM, vLLM, Loguru, Requests, Tenacity, Datasets, PyTorch with CUDA.

Author: Graham Anderson
Date: 2025-02-17
License: Apache 2.0
"""
import os
import sys
import re
import time
import torch
from datasets import load_dataset
from loguru import logger
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from vllm import SamplingParams
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
from trl import GRPOConfig, GRPOTrainer
import requests
import asyncio
from typing import List
# --- Configuration Constants ---
REWARD_BASE = 2.0
PENALTY_COEFFICIENT = 0.05  # Lower than 0.1 to avoid over-penalizing multi-step reasoning.
MIN_REWARD = 0.5
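# Worked example of the reward schedule these constants define (see _calculate_reward below):
#   correct on the first pass   -> max(2.0 - 0.05 * 0**2, 0.5) = 2.00
#   correct after 1 hint/retry  -> max(2.0 - 0.05 * 1**2, 0.5) = 1.95
#   correct after 3 hints       -> max(2.0 - 0.05 * 3**2, 0.5) = 1.55
#   still wrong after 3 hints   -> max(1.0 - 0.05 * 3**2, 0.5) = 0.55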
DEBUG_CONFIG = {
    "ENABLE_DEBUG": os.getenv("ENABLE_DEBUG", "false").lower() == "true",
    "DEBUG_SAMPLES": int(os.getenv("DEBUG_SAMPLES", "2")),
    "LOG_LEVEL": os.getenv("LOG_LEVEL", "INFO"),
}

# Configure the Loguru logger.
logger.add(
    "training_{time}.log",
    rotation="100 MB",
    retention="10 days",
    level=DEBUG_CONFIG["LOG_LEVEL"],
    backtrace=True,
    diagnose=True,
    enqueue=True,
)

# Apply the GRPO patch.
PatchFastRL("GRPO", FastLanguageModel)
# ---------------------- Helper Functions ----------------------
def extract_xml_answer(text: str) -> str:
    """Extract the answer from XML-formatted text."""
    try:
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()
    except Exception as e:
        logger.error(f"XML parsing error: {e} | Text snippet: {text[:100]}...")
        return ""


def _calculate_reward(iterations: int, is_correct: bool) -> float:
    """
    Calculate the reward as max(base - PENALTY_COEFFICIENT * iterations**2, MIN_REWARD),
    where base is REWARD_BASE for a correct answer and 1.0 otherwise.
    The quadratic penalty uses a mild coefficient so multi-step reasoning is not over-penalized.
    """
    base = REWARD_BASE if is_correct else 1.0
    penalty = (iterations ** 2) * PENALTY_COEFFICIENT
    return max(base - penalty, MIN_REWARD)
# Async retry decorator for teacher hints using LiteLLM.
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(Exception),
    before_sleep=lambda retry_state: logger.warning(
        f"Retrying teacher hint (attempt {retry_state.attempt_number})"
    ),
)
async def get_single_teacher_hint(student_ans: str, teacher_model) -> str:
    """
    Fetch a single teacher hint asynchronously using LiteLLM's API.
    Returns a hint string starting with "Aha!".
    """
    try:
        from litellm import acompletion
        response = await acompletion(
            model=teacher_model,
            messages=[{
                "role": "user",
                "content": (
                    f"Student answer: '{student_ans}'. "
                    "Provide a subtle hint (do not reveal the final answer) "
                    "to guide improvement. Begin with 'Aha!'."
                ),
            }],
            temperature=0.8,
            top_p=0.95,
            max_tokens=50,
        )
        hint = response.choices[0].message.content.strip()
        if not hint.startswith("Aha!"):
            hint = f"Aha! {hint}"
        return hint
    except Exception as e:
        logger.error(f"Teacher hint generation failed: {e}")
        raise
async def get_batch_teacher_hints(student_answers: list[str], teacher_model) -> list[str]:
    """Fetch teacher hints for a batch of student answers concurrently."""
    tasks = [get_single_teacher_hint(ans, teacher_model) for ans in student_answers]
    hints = await asyncio.gather(*tasks, return_exceptions=True)
    processed = []
    for idx, hint in enumerate(hints):
        if isinstance(hint, Exception):
            logger.error(f"Hint {idx} failed: {hint}")
            processed.append("Aha! Please reconsider your approach.")
        else:
            processed.append(hint)
    logger.debug(f"Obtained {len(processed)} teacher hints.")
    return processed
class TeacherHintManager:
    """Runs the async batch-hint requests on a dedicated event loop from synchronous code."""

    def __init__(self, teacher_model):
        self.loop = asyncio.new_event_loop()
        self.teacher_model = teacher_model

    def get_hints(self, student_answers: list[str]) -> list[str]:
        return self.loop.run_until_complete(
            get_batch_teacher_hints(student_answers, self.teacher_model)
        )
def tokenize_teacher_hint(hint: str, teacher_tokenizer) -> str:
    """
    Round-trip the teacher hint through the teacher tokenizer (encode then decode)
    so its surface form is aligned with what the student model consumes.
    """
    try:
        tokens = teacher_tokenizer.encode(hint, add_special_tokens=False)
        decoded_hint = teacher_tokenizer.decode(tokens, skip_special_tokens=True)
        return decoded_hint.strip()
    except Exception as e:
        logger.error(f"Error tokenizing teacher hint: {e}")
        return hint
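
# NOTE: `call_teacher_hint_litellm` is referenced by the external-teacher branch of the
# iterative reward function below but was not defined in the original gist. The sketch here
# is an assumption: it POSTs the prompt to the configured endpoint and expects a JSON body
# like {"text": "..."}; adapt the payload and response handling to your teacher API's
# actual schema.
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(requests.RequestException),
)
def call_teacher_hint_litellm(prompt: str, api_url: str, api_key: str | None = None) -> str:
    """Request a teacher hint from an external HTTP endpoint (request/response schema assumed)."""
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    resp = requests.post(
        api_url,
        json={"prompt": prompt, "max_tokens": 50, "temperature": 0.8, "top_p": 0.95},
        headers=headers,
        timeout=30,
    )
    resp.raise_for_status()
    hint = resp.json().get("text", "").strip()
    return hint if hint.startswith("Aha!") else f"Aha! {hint}"
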
# ---------------------- Reward Functions ----------------------
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 when the extracted answer exactly matches the ground truth."""
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    return [2.0 if s.strip() == a.strip() else 0.0 for s, a in zip(extracted, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when the extracted answer is an integer (GSM8K answers are numeric)."""
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    return [0.5 if s.strip().isdigit() else 0.0 for s in extracted]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that exactly match the <reasoning>/<answer> template."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [comp[0]["content"] for comp in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if m else 0.0 for m in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that contain the tags in order, with looser whitespace rules."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [comp[0]["content"] for comp in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if m else 0.0 for m in matches]


def count_xml(text: str) -> float:
    """Give partial credit for each well-placed tag; penalize trailing text after </answer>."""
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [comp[0]["content"] for comp in completions]
    return [count_xml(c) for c in contents]
def iterative_teacher_hint_reward_func(
    prompts,
    completions,
    answer,
    teacher_model,
    teacher_tokenizer,
    student_model,
    max_iterations=3,
    **kwargs,
) -> list[float]:
    """
    Iteratively evaluates the student's answer and, if incorrect, obtains a teacher hint.
    The teacher backend can be "local" or "external" (using LiteLLM).
    Teacher hints are tokenized to ensure compatibility.
    Each hint is appended to the student's chain-of-thought, and the student model re-generates its output.
    A quadratic penalty (with a mild coefficient) is applied per extra iteration.
    """
    rewards = []
    teacher_backend = kwargs.get("teacher_backend", "local")
    teacher_api_url = kwargs.get("teacher_api_url", None)
    teacher_api_key = kwargs.get("teacher_api_key", None)
    # Collect initial student answers and pre-fetch one hint per sample as a fallback.
    student_answers = [extract_xml_answer(comp[0]['content']) for comp in completions]
    hint_manager = TeacherHintManager(teacher_model)
    try:
        teacher_hints = hint_manager.get_hints(student_answers)
    except Exception as e:
        logger.error(f"Failed to get initial teacher hints: {e}")
        teacher_hints = ["Aha! Please reconsider your approach."] * len(student_answers)
    for idx, (comp, gt, init_hint) in enumerate(zip(completions, answer, teacher_hints)):
        try:
            student_output = comp[0]['content']
            iterations = 0
            reasoning_chain: List[str] = []
            while iterations < max_iterations:
                student_ans = extract_xml_answer(student_output)
                if student_ans.strip() == gt.strip():
                    logger.info(f"Sample {idx}: Correct answer achieved after {iterations} iterations.")
                    break
                # Prepare the teacher prompt.
                teacher_prompt = (
                    f"Student answer: '{student_ans}'. "
                    "Provide a subtle hint (do not reveal the final answer) to guide improvement. "
                    "Begin your hint with 'Aha!'."
                )
                if teacher_backend == "local":
                    try:
                        hint_raw = teacher_model.fast_generate(
                            [teacher_prompt],
                            sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50),
                            lora_request=None,
                        )[0].outputs[0].text.strip()
                    except Exception as e:
                        logger.error(f"Local teacher hint generation error (sample {idx}): {e}")
                        hint_raw = "Aha! Please reconsider your approach."
                elif teacher_backend == "external" and teacher_api_url is not None:
                    hint_raw = call_teacher_hint_litellm(teacher_prompt, teacher_api_url, teacher_api_key)
                else:
                    # No usable backend configured: fall back to the pre-fetched LiteLLM hint.
                    hint_raw = init_hint
                # Tokenize the teacher hint to ensure compatibility.
                teacher_hint = tokenize_teacher_hint(hint_raw, teacher_tokenizer)
                reasoning_chain.append(teacher_hint)
                logger.debug(f"Sample {idx} - Iteration {iterations + 1} teacher hint: {teacher_hint}")
                # Append the teacher's hint to the student output.
                student_output = student_output + "\n" + teacher_hint
                # Re-generate student output using sampling (not argmax) for diversity.
                try:
                    new_output = student_model.fast_generate(
                        [student_output],
                        sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=200),
                        lora_request=None,
                    )[0].outputs[0].text
                except Exception as e:
                    logger.error(f"Student re-generation error (sample {idx}, iteration {iterations + 1}): {e}")
                    new_output = student_output
                student_output = new_output
                iterations += 1
            final_ans = extract_xml_answer(student_output)
            is_correct = final_ans.strip() == gt.strip()
            reward = _calculate_reward(iterations, is_correct)
            logger.info(f"Sample {idx}: Final answer '{final_ans}' | Reward: {reward:.2f} after {iterations} iterations")
            logger.debug("Reasoning chain:\n" + "\n".join(reasoning_chain))
            rewards.append(reward)
        except Exception as e:
            logger.exception(f"Error processing sample {idx}: {e}")
            rewards.append(MIN_REWARD)
    torch.cuda.empty_cache()
    return rewards
# ---------------------- Data Preparation ----------------------
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
def get_gsm8k_questions(split="train") -> "Dataset":
    """Load and format the GSM8K dataset."""
    try:
        data = load_dataset('openai/gsm8k', 'main')[split]
        data = data.map(lambda x: {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['question']},
            ],
            'answer': x['answer'].split("####")[1].strip(),
        })
        logger.info(f"Loaded {len(data)} examples from the GSM8K {split} split.")
        return data
    except Exception as e:
        logger.error(f"Failed to load GSM8K dataset: {e}")
        raise
# ---------------------- GRPO Training Configuration ----------------------
training_args = GRPOConfig(
    use_vllm=True,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    num_generations=8,
    max_prompt_length=256,
    max_completion_length=200,
    max_steps=250,
    save_steps=250,
    max_grad_norm=0.1,
    report_to="none",
    output_dir="outputs",
)
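
# ---------------------- Model Initialization ----------------------
# The original gist references `student_model`, `tokenizer`, `teacher_model`, and
# `teacher_tokenizer` without defining them. The block below is a minimal sketch using
# Unsloth's standard loading flow; the model name, sequence length, LoRA rank, and teacher
# model id are assumptions and should be replaced with your own choices.
max_seq_length = 1024
lora_rank = 64

student_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",  # Assumption: any chat model works here.
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    fast_inference=True,       # Enables vLLM-backed fast_generate.
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,
)
student_model = FastLanguageModel.get_peft_model(
    student_model,
    r=lora_rank,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# Teacher: either a LiteLLM-routable model id (external backend) or a second local Unsloth model.
teacher_model = "openai/gpt-4o-mini"  # Assumption: any LiteLLM-routable model id.
teacher_tokenizer = tokenizer         # Assumption: reuse the student tokenizer for hint round-tripping.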
# ---------------------- Main Function ----------------------
def main():
    try:
        logger.info("Starting GRPO training with iterative teacher hints...")
        # Load the dataset and optionally restrict samples in debug mode.
        dataset_full = get_gsm8k_questions()
        if DEBUG_CONFIG["ENABLE_DEBUG"]:
            dataset_full = dataset_full.select(range(DEBUG_CONFIG["DEBUG_SAMPLES"]))
            logger.warning(f"Debug mode enabled: using {DEBUG_CONFIG['DEBUG_SAMPLES']} samples.")
        # Initialize GRPOTrainer with the iterative teacher-hint reward function.
        # NOTE: `reward_kwargs` is not an argument of TRL's GRPOTrainer in all versions; if your
        # version rejects it, bind these values onto the reward function (e.g. with
        # functools.partial) before passing it in `reward_funcs`.
        trainer = GRPOTrainer(
            model=student_model,
            processing_class=tokenizer,
            reward_funcs=[
                xmlcount_reward_func,
                soft_format_reward_func,
                strict_format_reward_func,
                int_reward_func,
                correctness_reward_func,
                iterative_teacher_hint_reward_func,  # The iterative student-teacher reward function.
            ],
            args=training_args,
            train_dataset=dataset_full,
            eval_dataset=dataset_full,
            reward_kwargs={
                "teacher_model": teacher_model,
                "teacher_tokenizer": teacher_tokenizer,
                "student_model": student_model,
                "max_iterations": 3,
                "teacher_backend": "external",  # Set to "local" to use the local teacher model.
                "teacher_api_url": "https://your-teacher-api-endpoint.com/complete",  # Replace with your endpoint.
                "teacher_api_key": "your-api-key-if-required",
            },
        )
        logger.info("Starting training...")
        train_results = trainer.train()
        logger.success("Training completed successfully!")
        logger.info(f"Training metrics: {train_results}")
        # Save the fine-tuned model adapter.
        adapter_path = f"adapters/grpo_teacher_{int(time.time())}"
        trainer.save_model(adapter_path)
        logger.info(f"Adapter saved to {adapter_path}")
        # (Optional) Push to the Hugging Face Hub.
        if training_args.push_to_hub:
            trainer.push_to_hub()
            logger.success("Model pushed to the Hugging Face Hub successfully!")
        # Inference: test the base model on a sample prompt.
        sample_prompt = tokenizer.apply_chat_template([
            {"role": "user", "content": "How many r's are in strawberry?"},
        ], tokenize=False, add_generation_prompt=True)
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            max_tokens=1024,
        )
        base_output = student_model.fast_generate(
            [sample_prompt],
            sampling_params=sampling_params,
            lora_request=None,
        )[0].outputs[0].text
        logger.info(f"Base Model Output:\n{base_output}")
        # Save LoRA adapters.
        student_model.save_lora("grpo_saved_lora")
        logger.info("LoRA adapters saved.")
        # Test the fine-tuned model with the saved LoRA adapters.
        sample_prompt_ft = tokenizer.apply_chat_template([
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": "How many r's are in strawberry?"},
        ], tokenize=False, add_generation_prompt=True)
        finetuned_output = student_model.fast_generate(
            [sample_prompt_ft],
            sampling_params=sampling_params,
            lora_request=student_model.load_lora("grpo_saved_lora"),
        )[0].outputs[0].text
        logger.info(f"Fine-tuned Model Output:\n{finetuned_output}")
    except Exception as e:
        logger.exception(f"Training failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()