-
Star
(1,229)
You must be signed in to star a gist -
Fork
(379)
You must be signed in to fork a gist
-
-
Save willccbb/4676755236bb08cab5f4e54a0475d6fb to your computer and use it in GitHub Desktop.
# train_grpo.py
#
# See https://github.com/willccbb/verifiers for ongoing developments
#
# The docstring below is raw (r"""...""") because the BibTeX ``\url{...}``
# would otherwise be parsed as a truncated ``\u`` unicode escape, which is a
# SyntaxError in a normal string literal.
r"""
citation:
@misc{brown2025grpodemo,
title={Granular Format Rewards for Eliciting Mathematical Reasoning Capabilities in Small Language Models},
author={Brown, William},
howpublished={\url{https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb}},
date = {2025-01-25},
note = {GitHub Gist}
}
"""
import re | |
import torch | |
from datasets import load_dataset, Dataset | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from peft import LoraConfig | |
from trl import GRPOConfig, GRPOTrainer | |
# Load and prep dataset

# System prompt asking the model to put its chain of thought inside
# <reasoning> tags and its final answer inside <answer> tags. The format
# reward functions in this script score completions against this layout.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template for a fully formatted assistant turn; fill with
# XML_COT_FORMAT.format(reasoning=..., answer=...). Used by the optional
# one-shot example in get_gsm8k_questions.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    """Return the text between the last ``<answer>`` tag and the next ``</answer>``.

    If ``<answer>`` is absent, the whole text is treated as the answer body.
    If ``</answer>`` is absent, everything after ``<answer>`` is kept.
    The result is whitespace-stripped.
    """
    # rpartition yields ('', '', text) when the tag is missing, so the
    # fallback "whole text" case needs no special handling.
    _, _, after_open = text.rpartition("<answer>")
    body, _, _ = after_open.partition("</answer>")
    return body.strip()
def extract_hash_answer(text: str) -> str | None: | |
if "####" not in text: | |
return None | |
return text.split("####")[1].strip().replace(",", "").replace("$", "") | |
def get_gsm8k_questions(split: str = "train", one_shot: bool = False) -> Dataset:
    """Load a GSM8K split and convert each row into a chat prompt for GRPO.

    Each output row gets:
      * ``prompt``: chat messages starting with ``SYSTEM_PROMPT`` and ending
        with the question as a user turn;
      * ``answer``: the gold answer parsed by ``extract_hash_answer`` (may be
        ``None`` if the source answer has no ``####`` marker).

    Args:
        split: Split of ``openai/gsm8k`` (config ``main``) to load.
        one_shot: If True, insert a worked user/assistant example formatted
            with ``XML_COT_FORMAT`` before the question (1-shot prompting).
            Defaults to False: system prompt + question only, as before.
    """
    example_turns = []
    if one_shot:
        example_turns = [
            {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            {'role': 'assistant', 'content': XML_COT_FORMAT.format(
                reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
                answer="7",
            )},
        ]

    def to_prompt(row):
        # Overwrites the raw 'answer' column with the parsed gold answer.
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                *example_turns,
                {'role': 'user', 'content': row['question']},
            ],
            'answer': extract_hash_answer(row['answer']),
        }

    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    return data.map(to_prompt)  # type: ignore


dataset = get_gsm8k_questions()
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 for each completion whose extracted answer equals the gold answer.

    Args:
        prompts: Batch of chat prompts; only the first prompt's last message
            is used, for the debug print.
        completions: Batch of completions, each read as ``completion[0]['content']``.
        answer: Gold answers aligned with ``completions``.

    Returns:
        One reward per completion: 2.0 when the text extracted by
        ``extract_xml_answer`` is exactly equal to the gold answer string,
        otherwise 0.0.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # Debug trace of the first sample only; guarded so an empty batch does
    # not raise IndexError.
    if prompts and responses and answer:
        q = prompts[0][-1]['content']
        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for each completion whose extracted answer is all digits.

    Applies ``str.isdigit`` to the text returned by ``extract_xml_answer``.
    Negative numbers (``-3``), decimals (``2.5``), comma-grouped values
    (``1,000``) and empty answers therefore earn 0.0.
    """
    answers = (extract_xml_answer(completion[0]['content']) for completion in completions)
    return [0.5 if ans.isdigit() else 0.0 for ans in answers]
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that follow the exact line-by-line XML layout.

    The whole completion must consist of: ``<reasoning>``, newline, reasoning
    text, newline, ``</reasoning>``, newline, ``<answer>``, newline, answer
    text, newline, ``</answer>``. It may optionally be followed by a newline
    (``$`` also tolerates one further final newline). Anything else scores 0.0.
    """
    # The newline after </answer> is optional. The previous pattern required
    # it, so a completion ending exactly at "</answer>" (nothing after the
    # tag) was never rewarded even when otherwise perfectly formatted.
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?$"
    responses = [completion[0]["content"] for completion in completions]
    return [0.5 if re.match(pattern, r, flags=re.DOTALL) else 0.0 for r in responses]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that begin with a <reasoning> block followed by an <answer> block.

    Looser than the strict format check:
      * no newlines are required inside the tags;
      * any whitespace (or none) may separate the two blocks;
      * text after ``</answer>`` is ignored.
    The completion must still start with ``<reasoning>``, because
    ``re.match`` anchors at the beginning of the string.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def count_xml(text: str) -> float:
    """Score partial adherence to the <reasoning>/<answer> XML layout (max 0.5).

    Awards 0.125 for each of these tag strings that occurs exactly once:
    ``<reasoning>\\n``, ``\\n</reasoning>\\n``, ``\\n<answer>\\n`` and
    ``\\n</answer>``.

    Subtracts 0.001 per character of text trailing the closing answer tag:
      * once for text after ``\\n</answer>\\n``, when that string is present;
      * once for text after ``\\n</answer>``, where one newline directly
        after the tag is free.
    A missing closing tag is never charged for the whole completion, and the
    penalty can never become a bonus.
    """
    def tail_after(sep: str) -> str:
        # Text after the last occurrence of ``sep``; "" when ``sep`` is absent.
        _, found, tail = text.rpartition(sep)
        return tail if found else ""

    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(tail_after("\n</answer>\n")) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= len(tail_after("\n</answer>").removeprefix("\n")) * 0.001
    return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Return ``count_xml`` of each completion's text (``completion[0]["content"]``)."""
    return [count_xml(completion[0]["content"]) for completion in completions]
# Base model to fine-tune; swap which line is commented to train Llama instead.
#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# Checkpoint directory and experiment (run) name, chosen by model family.
if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir = "outputs/Qwen-1.5B-GRPO"
    run_name = "Qwen-1.5B-GRPO-gsm8k"
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    # Matches the bfloat16 weights loaded below; float16 has been reported to
    # produce inf/NaN sampling probabilities with this script.
    bf16=True,
    # NOTE(review): newer TRL releases require the global batch size to be
    # divisible by num_generations — check against the installed TRL version.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # Completions sampled per prompt (the GRPO group); memory grows with it.
    num_generations=16,
    max_prompt_length=256,
    # NOTE(review): 786 may be a typo for 768 — confirm. Lower this first if
    # you run out of GPU memory during generation.
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)
# LoRA adapter settings. Only takes effect if ``peft_config=peft_config`` is
# passed to GRPOTrainer, where it is commented out by default.
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    # Module names to wrap; presumably the attention (q/k/v/o) and MLP
    # (up/down/gate) projections of Qwen/Llama-style checkpoints.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)
# Load the policy in bfloat16 and move it onto the default CUDA device.
# NOTE: attn_implementation="flash_attention_2" needs the flash-attn package
# and a supported GPU; switch to "sdpa" if it is unavailable.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None,
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token (presumably because not every checkpoint's
# tokenizer defines one).
tokenizer.pad_token = tokenizer.eos_token
# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    # Format-shaping rewards followed by answer rewards (integer-ness, then
    # exact correctness against the gold answer).
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
    # peft_config=peft_config,  # uncomment to train LoRA adapters instead
)
trainer.train()
When training the Qwen model on a GPU, I encountered the error: "probability tensor contains either `inf`, `nan` or element < 0".
Hi @fsxbhyy, did you load the model in `torch.bfloat16`? I used to encounter this issue when I loaded models in `torch.float16` instead of `bfloat16`. I guess `float16` in this context leads to numerical instability, resulting in NaN probabilities. Hope this helps!
I got the same problem. I trained a 7B model with batch_size == 1, but it keeps reporting OOM.
@harrywoo @Tuziking I had the same problem. I then noticed that these values are quite large for most setups:
max_prompt_length=256,
max_completion_length=786,
Generating up to 786 tokens per completion requires a lot of memory, especially if your group size (`num_generations`) is large. Try setting `max_completion_length` to 150 or 250 and see if it reduces memory usage. Hope this helps!
After some tuning and training on the gsm8k train set (7.47k examples), the Qwen2.5-0.5B model after GRPO scores ~51% on the gsm8k test set, vs. 41.6% for the base model as reported in their paper.
(在 gsm8k 训练集(7.47k 示例)上进行了一些调整和训练后,GRPO 训练后的 Qwen2.5-0.5B 模型在 gsm8k 测试集上得分约为 51%,而其论文报告的基础模型得分为 41.6%。)
Changes (变化):
- Tune beta to 0.01 (将 beta 调整为 0.01)
- LR 2e-6
- num_generations = 8
- 0.07 warmup, cosine schedule, no weight decay (0.07 预热,余弦,无 WD)
- 1x4x8 devices (32 total batch size) (1x4x8 设备,总批量大小为 32)
- max completion length 512 (最大完成长度 512)
- use only a system prompt (仅使用系统提示)
- evaluated with greedy decoding on vLLM (在 vLLM 上使用贪婪解码进行评估)
Are you using Qwen2.5-0.5B-Instruct as your base model? I noticed in Table 10 of the Qwen2.5 technical report that Qwen2.5-0.5B-Instruct scores 49.6 on GSM8K, while you mentioned your trained model achieved ~51%. From this perspective, there doesn't seem to be a significant performance improvement. Please correct me if I'm wrong.
To be honest, I tried using Qwen2.5-1.5B-Instruct as the base model to train Qwen-1.5B-GRPO, and its performance on GSM8K was 73.24, which is almost identical to what's reported in the Qwen2.5 technical report. However, I did notice that the training brought format-related benefits. At the beginning of training, the model struggled to follow the output format required by SYSTEM_PROMPT (`<reasoning>...</reasoning><answer>...</answer>`), but after training it could follow this format almost perfectly. This indicates that the training did bring certain benefits — but in my experiment, the improvements were primarily in formatting rather than problem-solving ability. Do you have any insights on this?
Does it also work on smaller models like >3B params model?