# Unsloth FP8 + GRPO test script
# Modeled after https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb
from unsloth import FastLanguageModel
import gc
import os
import re
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
from vllm import SamplingParams
import pandas as pd
import numpy as np
import torch
# ==============
# Model setup |
# ==============
load_in_fp8 = os.getenv("LOAD_IN_FP8", "true").lower()
load_in_fp8 = load_in_fp8 == "true" or load_in_fp8 == "1"
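# e.g. launch with `LOAD_IN_FP8=0` (or "false") in the environment to disable FP8
# loading; with the default parsed above, FP8 is on.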
max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.9,
    load_in_fp8 = load_in_fp8,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)
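# Optional sanity check (added here, not in the reference notebook): with LoRA
# attached, only the adapter weights should require gradients.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable_params:,} / {total_params:,} "
      f"({100 * trainable_params / total_params:.2f}%)")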
# ================
# Chat template |
# ================
reasoning_start = "<start_working_out>" # Acts as <think>
reasoning_end = "<end_working_out>" # Acts as </think>
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"
system_prompt = \
f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""
chat_template = \
    "{% if messages[0]['role'] == 'system' %}"\
        "{{ messages[0]['content'] + eos_token }}"\
        "{% set loop_messages = messages[1:] %}"\
    "{% else %}"\
        "{{ '{system_prompt}' + eos_token }}"\
        "{% set loop_messages = messages %}"\
    "{% endif %}"\
    "{% for message in loop_messages %}"\
        "{% if message['role'] == 'user' %}"\
            "{{ message['content'] }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{{ message['content'] + eos_token }}"\
        "{% endif %}"\
    "{% endfor %}"\
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"\
    "{% endif %}"
# Replace with our specific template:
chat_template = chat_template\
    .replace("'{system_prompt}'", f"'{system_prompt}'")\
    .replace("'{reasoning_start}'", f"'{reasoning_start}'")
tokenizer.chat_template = chat_template
# Sanity-check the template on a short conversation
print(tokenizer.apply_chat_template([
    {"role" : "user",      "content" : "What is 1+1?"},
    {"role" : "assistant", "content" : f"{reasoning_start}I think it's 2.{reasoning_end}{solution_start}2{solution_end}"},
    {"role" : "user",      "content" : "What is 2+2?"},
], tokenize = False, add_generation_prompt = True))
# =============
# Data setup |
# =============
dataset = load_dataset("unsloth/OpenMathReasoning-mini", split = "cot")
dataset = dataset.to_pandas()[
    ["expected_answer", "problem", "generated_solution"]
]
# Try converting to number - if not, replace with NaN
is_number = pd.to_numeric(pd.Series(dataset["expected_answer"]), errors = "coerce").notnull()
# Select only numbers
dataset = dataset.iloc[np.where(is_number)[0]]
def format_dataset(x):
    expected_answer = x["expected_answer"]
    problem = x["problem"]
    # Remove generated <think> and </think>
    thoughts = x["generated_solution"]
    thoughts = thoughts.replace("<think>", "").replace("</think>", "")
    # Strip newlines on left and right
    thoughts = thoughts.strip()
    # Add our custom formatting
    final_prompt = \
        reasoning_start + thoughts + reasoning_end + \
        solution_start + expected_answer + solution_end
    return [
        {"role" : "system",    "content" : system_prompt},
        {"role" : "user",      "content" : problem},
        {"role" : "assistant", "content" : final_prompt},
    ]
dataset["Messages"] = dataset.apply(format_dataset, axis = 1)
# Keep only pre fine-tuning examples that fit in max_seq_length/2 tokens since
# we don't want overly long reasoning traces.
dataset["N"] = dataset["Messages"].apply(lambda x: len(tokenizer.apply_chat_template(x)))
dataset = dataset.loc[dataset["N"] <= max_seq_length/2].copy()
# Render the messages as text with the chat template and convert to a Hugging Face Dataset
dataset["text"] = tokenizer.apply_chat_template(dataset["Messages"].values.tolist(), tokenize = False)
dataset = Dataset.from_pandas(dataset)
# ===============
# Pre-finetune |
# ===============
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1, # Use GA to mimic batch size!
        warmup_steps = 5,
        num_train_epochs = 2, # Number of passes over the pre fine-tuning data
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use this for WandB etc
    ),
)
# Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")
trainer_stats = trainer.train()
# Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
del dataset
torch.cuda.empty_cache()
gc.collect()
# ======================================
# Data setup again + reward functions |
# ======================================
dataset = load_dataset("open-r1/DAPO-Math-17k-Processed", "en", split = "train")
def extract_hash_answer(text):
    # if "####" not in text: return None
    # return text.split("####")[1].strip()
    return text
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["prompt"]},
    ],
    "answer": extract_hash_answer(x["solution"]),
})
# Add optional EOS token matching
solution_end_regex = r"</SOLUTION>[\s]{0,}" + \
    "(?:" + re.escape(tokenizer.eos_token) + ")?"
match_format = re.compile(
    rf"{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end_regex}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)
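# Illustrative check of the format regex (added for clarity, not in the original
# notebook): a well-formed completion should yield the text between the solution tags.
_demo = f"Let me think.{reasoning_end}{solution_start}4{solution_end}"
assert match_format.search(_demo).group(1) == "4"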
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores
def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count how many keywords are seen - we penalize if too many!
        # If we see 1, then plus some points!
        # No need to reward <start_working_out> since we always prepend it!
        # score += 0.5 if response.count(reasoning_start) == 1 else -1.0
        score += 0.5 if response.count(reasoning_end)  == 1 else -1.0
        score += 0.5 if response.count(solution_start) == 1 else -1.0
        score += 0.5 if response.count(solution_end)   == 1 else -1.0
        scores.append(score)
    return scores
def check_answer(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None \
        for r in responses
    ]
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(-2.0)
            continue
        # Correct answer gets 5 points!
        if guess == true_answer:
            score += 5.0
        # Match if spaces are seen, but less reward
        elif guess.strip() == true_answer.strip():
            score += 3.5
        else:
            # We also reward it if the answer is close via ratios!
            # Ie if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
                if   ratio >= 0.9 and ratio <= 1.1: score += 2.0
                elif ratio >= 0.8 and ratio <= 1.2: score += 1.5
                else: score -= 2.5 # Penalize wrong answers
            except:
                score -= 4.5 # Penalize
        scores.append(score)
    return scores
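# Illustrative smoke test of the reward functions above (added for clarity, not in
# the original notebook): a well-formatted, correct completion earns the top scores.
_prompts = [[{"role": "user", "content": "What is 2+2?"}]]
_completions = [[{"role": "assistant",
                  "content": f"It should be 4.{reasoning_end}{solution_start}4{solution_end}"}]]
print(match_format_exactly(_completions))           # expected: [3.0]
print(match_format_approximately(_completions))     # expected: [1.5]
print(check_answer(_prompts, _completions, ["4"]))  # expected: [5.0]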
match_numbers = re.compile(
    solution_start + r".*?[\s]{0,}([-]?[\d\.\,]{1,})",
    flags = re.MULTILINE | re.DOTALL
)
global PRINTED_TIMES
PRINTED_TIMES = 0
global PRINT_EVERY_STEPS
PRINT_EVERY_STEPS = 5
def check_numbers(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None \
        for r in responses
    ]
    scores = []
    # Print only every few steps
    global PRINTED_TIMES
    global PRINT_EVERY_STEPS
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        print(
            '*'*20 + f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}"
        )
    PRINTED_TIMES += 1
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(-2.5)
            continue
        # Convert to numbers
        try:
            true_answer = float(true_answer.strip())
            # Remove commas like in 123,456
            guess = float(guess.strip().replace(",", ""))
            scores.append(3.5 if guess == true_answer else -1.5)
        except:
            scores.append(0)
            continue
    return scores
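# Illustrative check of the number-extraction regex (added for clarity): it should
# pull out the first number after the solution tag, commas and all.
assert match_numbers.search(f"{solution_start}  1,234{solution_end}").group(1) == "1,234"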
# Get the 90th-percentile prompt length so we don't accidentally truncate prompts,
# i.e. we drop the longest 10% of prompts.
tokenized = dataset.map(
    lambda x: {"tokens" : tokenizer.apply_chat_template(x["prompt"], add_generation_prompt = True, tokenize = True)},
    batched = True,
)
tokenized = tokenized.map(lambda x: {"L" : len(x["tokens"])})
maximum_length = int(np.quantile(tokenized["L"], 0.9))
# Keep only samples no longer than the 90th-percentile length
dataset = dataset.select(np.where(np.array(tokenized["L"]) <= maximum_length)[0])
del tokenized
# ===========
# Training |
# ===========
max_prompt_length = maximum_length + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length
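# Informational print (added): show how the context window is split between
# prompt and completion for GRPO.
print(f"max_prompt_length = {max_prompt_length}, max_completion_length = {max_completion_length}")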
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    seed = 3407,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)
training_args = GRPOConfig(
    vllm_sampling_params = vllm_sampling_params,
    temperature = 1.0,
    learning_rate = 5e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 100,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset,
)
# Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")
trainer_stats = trainer.train()
# Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
# ============
# Inference |
# ============
text = "What is the sqrt of 101?"
sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text
print(output)
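# Optional follow-up, sketched from the reference notebook linked at the top:
# save the GRPO-trained LoRA, then generate again with the chat template applied
# and the LoRA loaded (save_lora/load_lora are assumed to behave as in that notebook).
model.save_lora("grpo_saved_lora")
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": "What is the sqrt of 101?"},
]
text = tokenizer.apply_chat_template(messages, add_generation_prompt = True, tokenize = False)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
print(output)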