# grpo_experiment_found_online.py
import logging
import os
from dataclasses import dataclass
from datetime import datetime

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import random
import re

import torch
from transformers.trainer_utils import get_last_checkpoint
from transformers import AutoTokenizer
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer, get_peft_config, ModelConfig, TrlParser

########################
# Custom dataclasses
########################
@dataclass
class ScriptArguments:
    dataset_id_or_path: str = "Jiayi-Pan/Countdown-Tasks-3to4"
    dataset_splits: str = "train"
    tokenizer_name_or_path: str = None

########################
# Setup logging
########################
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)


########################
# Helper functions
########################
def format_reward_func(completions, target, **kwargs):
    """
    Checks that each completion follows the expected format:
    <think>...</think><answer>...</answer>

    Args:
        completions (list[str]): Generated outputs
        target (list[str]): Expected answers

    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, target):
        try:
            # Prepend a synthetic <think>: it's already part of the prompt and prefilled
            # for the assistant, so the raw completion alone would not match the regex.
            completion = "<think>" + completion
            if random.random() < 0.1:  # 10% chance to write samples into a file
                os.makedirs("completion_samples", exist_ok=True)
                log_file = os.path.join("completion_samples", "completion_samples.txt")
                with open(log_file, "a") as f:
                    f.write("\n\n==============\n")
                    f.write(completion)
            # Check if the format is correct
            regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
            match = re.search(regex, completion, re.DOTALL)
            # If the format is not correct, the reward is 0
            if match is None or len(match.groups()) != 2:
                rewards.append(0.0)
            else:
                rewards.append(1.0)
        except Exception:
            rewards.append(0.0)
    return rewards
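

# Illustrative sanity check (not part of the original script and never called during
# training): shows how format_reward_func scores a well-formed vs. a malformed
# completion. The sample strings are made up; the function itself prepends the missing
# <think> tag and may log a sample to completion_samples/ as a side effect.
def _demo_format_reward():
    good = "55 + 36 = 91 and 91 - 7 - 19 = 65.</think>\n<answer> 55 + 36 - 7 - 19 </answer>"
    bad = "The answer is 65."
    return format_reward_func(completions=[good, bad], target=["65", "65"])  # -> [1.0, 0.0]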


def equation_reward_func(completions, target, nums, **kwargs):
    """
    Evaluates completions based on the mathematical correctness of the answer.

    Args:
        completions (list[str]): Generated outputs
        target (list[str]): Expected answers
        nums (list[list[int]]): Available numbers for each sample

    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt, numbers in zip(completions, target, nums):
        try:
            # Prepend a synthetic <think>: it's already part of the prompt and prefilled
            # for the assistant, so the raw completion alone would not match the regex.
            completion = "<think>" + completion
            # Check if the format is correct
            match = re.search(r"<answer>(.*?)<\/answer>", completion)
            if match is None:
                rewards.append(0.0)
                continue
            # Extract the "answer" part from the completion
            equation = match.group(1).strip()
            # Extract all numbers from the equation
            used_numbers = [int(n) for n in re.findall(r"\d+", equation)]
            # Check if all numbers are used exactly once
            if sorted(used_numbers) != sorted(numbers):
                rewards.append(0.0)
                continue
            # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
            allowed_pattern = r"^[\d+\-*/().\s]+$"
            if not re.match(allowed_pattern, equation):
                rewards.append(0.0)
                continue
            # Evaluate the equation with restricted globals and locals
            result = eval(equation, {"__builtins__": None}, {})
            # Check if the equation is correct and matches the ground truth
            if abs(float(result) - float(gt)) < 1e-5:
                rewards.append(1.0)
                if random.random() < 0.10:  # 10% chance to write fully successful samples into a file
                    os.makedirs("completion_samples", exist_ok=True)
                    log_file = os.path.join("completion_samples", "success_completion_samples.txt")
                    with open(log_file, "a") as f:
                        f.write("\n\n==============\n")
                        f.write(completion)
            else:
                rewards.append(0.0)
        except Exception:
            # If evaluation fails, the reward is 0
            rewards.append(0.0)
    return rewards
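

# Illustrative sanity check (not part of the original script and never called during
# training): equation_reward_func returns 1.0 only when the <answer> equation uses
# every provided number exactly once and evaluates to the target. Values are made up.
def _demo_equation_reward():
    good = "Combining them: 55 + 36 - 7 - 19 = 65.</think>\n<answer> 55 + 36 - 7 - 19 </answer>"
    bad = "Partial attempt.</think>\n<answer> 55 + 36 </answer>"
    return equation_reward_func(
        completions=[good, bad],
        target=["65", "65"],
        nums=[[55, 36, 7, 19], [55, 36, 7, 19]],
    )  # -> [1.0, 0.0]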


def get_checkpoint(training_args: GRPOConfig):
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    return last_checkpoint


def grpo_function(
    model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig
):
    #########################
    # Log parameters
    #########################
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    ################
    # Load tokenizer
    ################
    tokenizer = AutoTokenizer.from_pretrained(
        (
            script_args.tokenizer_name_or_path
            if script_args.tokenizer_name_or_path
            else model_args.model_name_or_path
        ),
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ###############
    # Load datasets
    ###############
    # Load dataset from Hugging Face Hub
    dataset = load_dataset(script_args.dataset_id_or_path, split=script_args.dataset_splits)
    # Select a random subset of 50k samples
    dataset = dataset.shuffle(seed=42).select(range(50000))

    #####################
    # Prepare and format dataset
    #####################
    # Generate the R1 prompt with a prefix so the model already starts with the thinking process
    def generate_r1_prompt(numbers, target):
        r1_prefix = [
            {
                "role": "system",
                "content": "You are a helpful assistant. You first think about the reasoning process in your mind and then provide the user with the answer.",
            },
            {
                "role": "user",
                "content": f"Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) one or multiple times but each number can only be used once. Show your work in <think> </think> tags. And return the final equation in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>. Think step by step inside <think> tags.",
            },
            {
                "role": "assistant",
                "content": "Let me solve this step by step.\n<think>",
            },
        ]
        return {
            "prompt": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True),
            "target": target,
            "nums": numbers,
        }

    # Convert our dataset to the R1 prompt
    dataset = dataset.map(lambda x: generate_r1_prompt(x["nums"], x["target"]))
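    # Note: with continue_final_message=True the rendered prompt ends mid-turn, right
    # after "Let me solve this step by step.\n<think>", so generation continues the
    # assistant's thinking directly (the exact text depends on the model's chat template).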

    # Split the dataset into train and test
    train_test_split = dataset.train_test_split(test_size=0.1)
    train_dataset = train_test_split["train"]
    test_dataset = train_test_split["test"]

    #########################
    # Instantiate GRPO trainer
    #########################
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=[format_reward_func, equation_reward_func],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        peft_config=get_peft_config(model_args),
    )
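    # Note: with multiple reward_funcs, GRPOTrainer combines the per-function rewards
    # into one scalar per completion (by default an equally weighted sum, adjustable
    # via GRPOConfig.reward_weights), so a completion here can earn up to 2.0.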

    ###############
    # Training loop
    ###############
    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")

    # Train the model
    logger.info(
        f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for {training_args.num_train_epochs} epochs ***'
    )
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)

    # Log and save metrics
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    logger.info("*** Training complete ***")

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.model.config.use_cache = True
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")
    training_args.distributed_state.wait_for_everyone()  # wait for all processes
    tokenizer.save_pretrained(training_args.output_dir)
    logger.info(f"Tokenizer saved to {training_args.output_dir}")

    # Save everything else on the main process
    if trainer.accelerator.is_main_process:
        trainer.create_model_card({"tags": ["rl", "grpo", "tutorial", "philschmid"]})
    # Push to hub if needed
    if training_args.push_to_hub is True:
        logger.info("Pushing to hub...")
        trainer.push_to_hub()

    logger.info("*** Training complete! ***")


def main():
    parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
    model_args, script_args, training_args = parser.parse_args_and_config()
    # Run the main training loop
    grpo_function(model_args, script_args, training_args)


if __name__ == "__main__":
    main()
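

# Example invocation (illustrative only; the model name and hyperparameter values below
# are placeholders, and TrlParser can alternatively read them from a YAML file via --config):
#
#   python grpo_experiment_found_online.py \
#       --model_name_or_path Qwen/Qwen2.5-3B-Instruct \
#       --output_dir runs/grpo-countdown \
#       --per_device_train_batch_size 1 \
#       --num_generations 8 \
#       --max_completion_length 512 \
#       --learning_rate 5e-7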