Unsloth FP8 + GRPO test script
# Modeled after https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb
from unsloth import FastLanguageModel
import gc
import os
import re
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
from vllm import SamplingParams
import pandas as pd
import numpy as np
import torch
# ==============
# Model setup  |
# ==============
load_in_fp8 = os.getenv("LOAD_IN_FP8", "true").lower()
load_in_fp8 = load_in_fp8 == "true" or load_in_fp8 == "1"
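# The FP8 path can be toggled from the shell without editing the script, e.g.
# (hypothetical invocation, assuming this file is saved as grpo_fp8_test.py):
#   LOAD_IN_FP8=false python grpo_fp8_test.py   # BF16 baseline
#   LOAD_IN_FP8=true  python grpo_fp8_test.py   # FP8 run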
max_seq_length = 2048  # Can increase for longer reasoning traces
lora_rank = 32  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.9,
    load_in_fp8 = load_in_fp8,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)
# ================
# Chat template  |
# ================
reasoning_start = "<start_working_out>"  # Acts as <think>
reasoning_end = "<end_working_out>"  # Acts as </think>
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

system_prompt = \
f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""
print(system_prompt)
chat_template = \
    "{% if messages[0]['role'] == 'system' %}"\
        "{{ messages[0]['content'] + eos_token }}"\
        "{% set loop_messages = messages[1:] %}"\
    "{% else %}"\
        "{{ '{system_prompt}' + eos_token }}"\
        "{% set loop_messages = messages %}"\
    "{% endif %}"\
    "{% for message in loop_messages %}"\
        "{% if message['role'] == 'user' %}"\
            "{{ message['content'] }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{{ message['content'] + eos_token }}"\
        "{% endif %}"\
    "{% endfor %}"\
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"\
    "{% endif %}"
# Replace with our specific template:
chat_template = chat_template\
    .replace("'{system_prompt}'", f"'{system_prompt}'")\
    .replace("'{reasoning_start}'", f"'{reasoning_start}'")
tokenizer.chat_template = chat_template
print(tokenizer.apply_chat_template([
    {"role" : "user", "content" : "What is 1+1?"},
    {"role" : "assistant", "content" : f"{reasoning_start}I think it's 2.{reasoning_end}{solution_start}2{solution_end}"},
    {"role" : "user", "content" : "What is 2+2?"},
], tokenize = False, add_generation_prompt = True))
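# The call above should render, roughly, as one continuous string (sketch; <eos>
# stands in for tokenizer.eos_token, line breaks added here only for readability):
#   <system_prompt><eos>
#   What is 1+1?<start_working_out>I think it's 2.<end_working_out><SOLUTION>2</SOLUTION><eos>
#   What is 2+2?<start_working_out>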
# =============
# Data setup  |
# =============
dataset = load_dataset("unsloth/OpenMathReasoning-mini", split = "cot")
dataset = dataset.to_pandas()[
    ["expected_answer", "problem", "generated_solution"]
]

# Try converting to number - if not, replace with NaN
is_number = pd.to_numeric(pd.Series(dataset["expected_answer"]), errors = "coerce").notnull()
# Select only numbers
dataset = dataset.iloc[np.where(is_number)[0]]
def format_dataset(x):
    expected_answer = x["expected_answer"]
    problem = x["problem"]

    # Remove generated <think> and </think>
    thoughts = x["generated_solution"]
    thoughts = thoughts.replace("<think>", "").replace("</think>", "")

    # Strip newlines on left and right
    thoughts = thoughts.strip()

    # Add our custom formatting
    final_prompt = \
        reasoning_start + thoughts + reasoning_end + \
        solution_start + expected_answer + solution_end

    return [
        {"role" : "system", "content" : system_prompt},
        {"role" : "user", "content" : problem},
        {"role" : "assistant", "content" : final_prompt},
    ]
| dataset["Messages"] = dataset.apply(format_dataset, axis = 1) | |
| # Let's truncate the pre fine-tuning dataset to max_seq_length/2 since | |
| # we don't want too long reasoning traces. | |
| dataset["N"] = dataset["Messages"].apply(lambda x: len(tokenizer.apply_chat_template(x))) | |
| dataset = dataset.loc[dataset["N"] <= max_seq_length/2].copy() | |
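# With max_seq_length = 2048 this keeps only SFT examples whose full chat-templated
# length is at most 1024 tokens.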
# We then tokenize the messages and convert them to a Hugging Face compatible dataset format
dataset["text"] = tokenizer.apply_chat_template(dataset["Messages"].values.tolist(), tokenize = False)
dataset = Dataset.from_pandas(dataset)
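# Optional sanity check (assumes at least one row survived the filters above):
# print(dataset[0]["text"][:500])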
# ===============
# Pre-finetune  |
# ===============
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1,  # Use GA to mimic batch size!
        warmup_steps = 5,
        num_train_epochs = 2,  # Set this for 1 full training run.
        learning_rate = 2e-4,  # Reduce to 2e-5 for long training runs
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none",  # Use this for WandB etc
    ),
)
# Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

trainer_stats = trainer.train()

# Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

del dataset
torch.cuda.empty_cache()
gc.collect()
# ======================================
# Data setup again + reward functions  |
# ======================================
dataset = load_dataset("open-r1/DAPO-Math-17k-Processed", "en", split = "train")

def extract_hash_answer(text):
    # if "####" not in text: return None
    # return text.split("####")[1].strip()
    return text
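# Note (assumption, inferred from the commented-out "####" handling above): for
# DAPO-Math-17k-Processed the "solution" column already holds the final answer,
# so extract_hash_answer simply passes it through.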
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": x["prompt"]},
    ],
    "answer": extract_hash_answer(x["solution"]),
})

# Add optional EOS token matching
solution_end_regex = r"</SOLUTION>[\s]{0,}" + \
    "(?:" + re.escape(tokenizer.eos_token) + ")?"

match_format = re.compile(
    rf"{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end_regex}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)
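# Illustrative example of what match_format accepts (not from the original script):
#   "<end_working_out><SOLUTION>42</SOLUTION>"  ->  group(1) == "42"
# A completion that never closes the working-out / solution tags will not match.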
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count how many keywords are seen - we penalize if too many!
        # If we see exactly 1, add some points!
        # No need to reward <start_working_out> since we always prepend it!
        # score += 0.5 if response.count(reasoning_start) == 1 else -1.0
        score += 0.5 if response.count(reasoning_end) == 1 else -1.0
        score += 0.5 if response.count(solution_start) == 1 else -1.0
        score += 0.5 if response.count(solution_end) == 1 else -1.0
        scores.append(score)
    return scores
def check_answer(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(-2.0)
            continue
        # Correct answer gets 5 points!
        if guess == true_answer:
            score += 5.0
        # Match if spaces are seen, but less reward
        elif guess.strip() == true_answer.strip():
            score += 3.5
        else:
            # We also reward it if the answer is close via ratios!
            # Ie if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
                if ratio >= 0.9 and ratio <= 1.1: score += 2.0
                elif ratio >= 0.8 and ratio <= 1.2: score += 1.5
                else: score -= 2.5  # Penalize wrong answers
            except Exception:
                score -= 4.5  # Penalize answers that can't be compared numerically
        scores.append(score)
    return scores
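# Reward schedule implemented by check_answer above:
#   exact string match            +5.0
#   match after strip()           +3.5
#   within 10% numerically        +2.0
#   within 20% numerically        +1.5
#   otherwise                     -2.5
#   not numerically comparable    -4.5
#   missing/invalid format        -2.0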
match_numbers = re.compile(
    solution_start + r".*?[\s]{0,}([-]?[\d\.\,]{1,})",
    flags = re.MULTILINE | re.DOTALL
)
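# Illustrative example (not from the original script): match_numbers grabs the first
# numeric token after <SOLUTION>, so
#   "<SOLUTION>  -1,234.5 apples"  ->  group(1) == "-1,234.5"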
# Counters used to print one sampled completion every few reward calls
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 5
def check_numbers(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    # Print only every few steps
    global PRINTED_TIMES
    global PRINT_EVERY_STEPS
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        print(
            '*'*20 + f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}"
        )
    PRINTED_TIMES += 1

    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(-2.5)
            continue
        # Convert to numbers
        try:
            true_answer = float(true_answer.strip())
            # Remove commas like in 123,456
            guess = float(guess.strip().replace(",", ""))
            scores.append(3.5 if guess == true_answer else -1.5)
        except Exception:
            scores.append(0)
            continue
    return scores
# Get the 90th-percentile prompt length so we don't accidentally truncate prompts!
# i.e. we'll remove the longest 10% of prompts.
tokenized = dataset.map(
    lambda x: {"tokens" : tokenizer.apply_chat_template(x["prompt"], add_generation_prompt = True, tokenize = True)},
    batched = True,
)
tokenized = tokenized.map(lambda x: {"L" : len(x["tokens"])})
maximum_length = int(np.quantile(tokenized["L"], 0.9))

# Keep only samples at or below the 90th-percentile prompt length
dataset = dataset.select(np.where(np.array(tokenized["L"]) <= maximum_length)[0])
del tokenized
# ===========
# Training  |
# ===========
max_prompt_length = maximum_length + 1  # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length
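# For example (illustrative numbers only): if the 90th-percentile prompt length came
# out at 400 tokens, max_prompt_length would be 401 and max_completion_length 1647.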
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    seed = 3407,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

training_args = GRPOConfig(
    vllm_sampling_params = vllm_sampling_params,
    temperature = 1.0,
    learning_rate = 5e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,  # Increase to 4 for smoother training
    num_generations = 4,  # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1,  # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 100,
    report_to = "none",  # Can use Weights & Biases
    output_dir = "outputs",
)
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset,
)
# Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

trainer_stats = trainer.train()

# Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
# ============
# Inference  |
# ============
# Quick generation check with a raw prompt (no chat template) and without the
# trained LoRA applied (lora_request = None).
text = "What is the sqrt of 101?"

sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text
print(output)
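# Possible follow-up (sketch, commented out; assumes the save_lora / load_lora helpers
# used in the reference notebook): compare against generation with the chat template
# and the GRPO-trained LoRA applied.
# model.save_lora("grpo_saved_lora")
# text = tokenizer.apply_chat_template(
#     [{"role": "user", "content": "What is the sqrt of 101?"}],
#     add_generation_prompt = True,
#     tokenize = False,
# )
# output = model.fast_generate(
#     text,
#     sampling_params = sampling_params,
#     lora_request = model.load_lora("grpo_saved_lora"),
# )[0].outputs[0].text
# print(output)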