Applying GRPO to DeepSeek-R1-Distill-Qwen-1.5B with LIMO Dataset
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3b71edb4",
   "metadata": {},
   "source": [
    "# Applying GRPO to DeepSeek-R1-Distill-Qwen-1.5B with LIMO Dataset\n",
    "\n",
    "This notebook provides a step-by-step tutorial for applying **Group Relative Policy Optimization (GRPO)** to the distilled model **DeepSeek-R1-Distill-Qwen-1.5B** using the high-quality LIMO dataset. We will cover:\n",
    "\n",
    "1. **Setup & Installation** – Installing dependencies and verifying GPU availability.\n",
    "2. **Model & Dataset Preparation** – Loading the model, tokenizer, and dataset, and formatting prompts.\n",
    "3. **Reinforcement Learning Fine-Tuning (GRPO)** – Implementing a simplified GRPO training loop, including reward computation and KL regularization.\n",
    "4. **Evaluation & Performance Metrics** – Demonstrating how to evaluate the fine-tuned model on benchmark tasks.\n",
    "5. **Hyperparameter Ablations & Future Directions** – Discussion of tuning and potential improvements.\n",
    "\n",
    "Let's begin!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b55f4e70",
   "metadata": {},
   "source": [
    "## 1. Setup & Installation\n",
    "\n",
    "We first install the necessary libraries, including PyTorch, Hugging Face Transformers, TRL (for reinforcement learning), the Datasets library, and bitsandbytes for 8-bit optimization. Then we verify that a GPU is available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4e7711",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: .dev versions of TRL (e.g. 0.15.0.dev0) are generally not published to PyPI;\n",
    "# if the pin below fails, install TRL from source or pin the latest stable release instead.\n",
    "!pip install transformers==4.48.2 trl==0.15.0.dev0 datasets bitsandbytes accelerate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6b72a2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "print(\"Torch version:\", torch.__version__)\n",
    "if torch.cuda.is_available():\n",
    "    device_name = torch.cuda.get_device_name(0)\n",
    "    print(\"GPU detected:\", device_name)\n",
    "    # Enable TF32 for faster matrix multiplication on supported GPUs\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "else:\n",
    "    print(\"No GPU found. Please enable a GPU runtime for training.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb314410",
   "metadata": {},
   "source": [
    "## 2. Model & Dataset Preparation\n",
    "\n",
    "We now load the **DeepSeek-R1-Distill-Qwen-1.5B** model and its tokenizer from Hugging Face, and load the LIMO dataset. The dataset consists of high-quality reasoning samples with a `question`, a detailed `solution`, and the final `answer`.\n",
    "\n",
    "We also define a helper function `format_prompt` that formats the question into a prompt instructing the model to output a reasoning chain and final answer using the tags `<think>` and `<answer>`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3f3f80e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "model_name = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    "    # Uncomment the following line if the model requires custom code:\n",
    "    # trust_remote_code=True,\n",
    ")\n",
    "\n",
    "# Quick test generation\n",
    "prompt_test = \"What is the capital of France?\"\n",
    "inputs_test = tokenizer(prompt_test, return_tensors=\"pt\").to(model.device)\n",
    "outputs_test = model.generate(**inputs_test, max_new_tokens=10)\n",
    "print(\"Test output:\", tokenizer.decode(outputs_test[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8f4d2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load the LIMO dataset\n",
    "dataset = load_dataset(\"GAIR/LIMO\")\n",
    "train_data = dataset[\"train\"]\n",
    "print(\"Total training samples:\", len(train_data))\n",
    "\n",
    "# Display a sample\n",
    "sample = train_data[0]\n",
    "print(\"Question:\", sample[\"question\"])\n",
    "print(\"Solution (excerpt):\", sample[\"solution\"][:100] + \"...\")\n",
    "print(\"Answer:\", sample[\"answer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaeb2a11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_prompt(question):\n",
    "    \"\"\"\n",
    "    Format the prompt to instruct the model to output a chain-of-thought and final answer.\n",
    "    \"\"\"\n",
    "    instruction = (\n",
    "        \"Solve the following problem step by step, then give the final answer. \"\n",
    "        \"Format your response as: <think>[reasoning]</think><answer>[final answer]</answer>.\"\n",
    "    )\n",
    "    return f\"{instruction}\\nQuestion: {question}\\nSolution:\"\n",
    "\n",
    "# Test the formatting\n",
    "formatted_prompt = format_prompt(sample[\"question\"])\n",
    "print(formatted_prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0fce501",
   "metadata": {},
   "source": [
    "## 3. Reinforcement Learning Fine-Tuning (GRPO)\n",
    "\n",
    "In this section, we implement a simplified GRPO training loop. The main steps are:\n",
    "\n",
    "- **Sampling:** For each prompt, generate multiple outputs (a group) from the model.\n",
    "- **Reward Scoring:** Compute a reward for each output based on answer accuracy and proper formatting.\n",
    "- **Advantage Calculation:** Compute each output's advantage by comparing its reward to the group average.\n",
    "- **Policy Optimization:** Update the model weights using the advantage-weighted log-likelihood loss plus a KL divergence penalty that keeps the model close to the reference (base) policy.\n",
    "\n",
    "We use a default learning rate of `1e-6`, a group size of 7, and a KL weight `β = 0.04`. We also use an 8-bit AdamW optimizer (via bitsandbytes), which stores optimizer states in 8 bits for memory efficiency."
   ]
  },
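  {
   "cell_type": "markdown",
   "id": "a1b2c3d4",
   "metadata": {},
   "source": [
    "Before the full training loop, here is a small illustrative sketch of the group-relative advantage idea using made-up reward values (it is not part of the training pipeline). Note that the full GRPO objective typically also normalizes by the group's standard deviation; the simplified loop below uses only the mean baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: group-relative advantages from made-up reward values.\n",
    "group_rewards = [1.1, 0.1, -0.1, 1.1, 0.1, -0.1, 0.1]  # one hypothetical reward per sampled output\n",
    "\n",
    "mean_r = sum(group_rewards) / len(group_rewards)\n",
    "# Mean-baseline advantages (what the simplified loop below uses)\n",
    "advantages_mean = [r - mean_r for r in group_rewards]\n",
    "\n",
    "# Full GRPO additionally divides by the group's standard deviation\n",
    "std_r = (sum((r - mean_r) ** 2 for r in group_rewards) / len(group_rewards)) ** 0.5\n",
    "advantages_norm = [(r - mean_r) / (std_r + 1e-8) for r in group_rewards]\n",
    "\n",
    "print(\"Mean-baseline advantages:\", [round(a, 3) for a in advantages_mean])\n",
    "print(\"Std-normalized advantages:\", [round(a, 3) for a in advantages_norm])"
   ]
  },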
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0b7a87d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import AdamW  # standard 32-bit AdamW (transformers.AdamW is deprecated)\n",
    "\n",
    "# Hyperparameters\n",
    "learning_rate = 1e-6\n",
    "tokens_per_generation = 4096  # Maximum tokens per generation (can be ablated)\n",
    "group_size = 7\n",
    "beta = 0.04\n",
    "\n",
    "# Initialize the 8-bit AdamW optimizer (using bitsandbytes)\n",
    "import bitsandbytes as bnb\n",
    "optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# Optionally, use standard 32-bit AdamW:\n",
    "# optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# Load a frozen copy of the initial model to serve as the reference for the KL penalty\n",
    "ref_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(model.device)\n",
    "ref_model.eval()\n",
    "for param in ref_model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "def reward_function(question, generated_text, true_answer):\n",
    "    \"\"\"\n",
    "    A simple rule-based reward:\n",
    "    - +0.1 bonus if the output contains both <think> and <answer> tags\n",
    "    - +1.0 if the extracted answer matches the true answer\n",
    "    - -0.1 if an answer is extracted but does not match\n",
    "    \"\"\"\n",
    "    answer = None\n",
    "    if \"<answer>\" in generated_text and \"</answer>\" in generated_text:\n",
    "        start = generated_text.index(\"<answer>\") + len(\"<answer>\")\n",
    "        end = generated_text.index(\"</answer>\")\n",
    "        answer = generated_text[start:end].strip()\n",
    "    else:\n",
    "        # Fallback: take the last whitespace-separated token as the answer (if any)\n",
    "        tokens = generated_text.strip().split()\n",
    "        answer = tokens[-1] if tokens else None\n",
    "\n",
    "    reward = 0.0\n",
    "    # Bonus for proper formatting\n",
    "    if \"<think>\" in generated_text and \"</think>\" in generated_text and \"<answer>\" in generated_text and \"</answer>\" in generated_text:\n",
    "        reward += 0.1\n",
    "\n",
    "    # Reward based on answer accuracy\n",
    "    if answer is not None:\n",
    "        pred_ans = answer.strip().strip('.')\n",
    "        true_ans = str(true_answer).strip().strip('.')\n",
    "        if pred_ans == true_ans:\n",
    "            reward += 1.0\n",
    "        else:\n",
    "            reward -= 0.1\n",
    "\n",
    "    return reward\n",
    "\n",
    "print(\"Optimizer and reward function set up.\")"
   ]
  },
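  {
   "cell_type": "markdown",
   "id": "a1b2c3d6",
   "metadata": {},
   "source": [
    "As a quick sanity check, the cell below scores a few hand-written outputs with `reward_function`. These strings and the true answer `\"42\"` are made up for illustration; they are not model generations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-written example outputs for a hypothetical question whose true answer is \"42\"\n",
    "examples = [\n",
    "    \"<think>Add the parts: 40 + 2 = 42.</think><answer>42</answer>\",  # well formatted, correct\n",
    "    \"<think>Some reasoning...</think><answer>41</answer>\",            # well formatted, wrong\n",
    "    \"The answer is 42\",                                               # unformatted, correct last token\n",
    "]\n",
    "for text in examples:\n",
    "    r = reward_function(\"hypothetical question\", text, \"42\")\n",
    "    print(f\"reward={r:+.2f}  |  {text[:60]}\")"
   ]
  },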
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c51b0ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "model.train()\n",
    "max_train_steps = 2   # Demo steps; in practice, use many more steps\n",
    "grad_accum_steps = 8  # Effective batch: grad_accum_steps * group_size\n",
    "\n",
    "# Shuffle training indices\n",
    "indices = list(range(len(train_data)))\n",
    "random.shuffle(indices)\n",
    "\n",
    "step = 0\n",
    "optimizer.zero_grad()\n",
    "\n",
    "for i, idx in enumerate(indices[: max_train_steps * grad_accum_steps]):\n",
    "    question = train_data[idx][\"question\"]\n",
    "    true_answer = train_data[idx][\"answer\"]\n",
    "    prompt = format_prompt(question)\n",
    "    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(model.device)\n",
    "\n",
    "    # Generate a group of outputs\n",
    "    generated_texts = []\n",
    "    for _ in range(group_size):\n",
    "        output_ids = model.generate(\n",
    "            input_ids,\n",
    "            max_new_tokens=200,  # For demo; in practice, use tokens_per_generation\n",
    "            do_sample=True,\n",
    "            temperature=1.0,\n",
    "            # Note: this only stops generation early if \"</answer>\" exists as a single\n",
    "            # token in the vocabulary; otherwise generation runs until max_new_tokens.\n",
    "            eos_token_id=tokenizer.convert_tokens_to_ids(\"</answer>\")\n",
    "        )\n",
    "        generated = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
    "        generated_texts.append(generated)\n",
    "\n",
    "    # Compute rewards and group-relative advantages\n",
    "    rewards = [reward_function(question, text, true_answer) for text in generated_texts]\n",
    "    baseline = sum(rewards) / len(rewards)\n",
    "    advantages = [r - baseline for r in rewards]\n",
    "\n",
    "    # Compute the policy loss and an approximate KL penalty in a single pass\n",
    "    policy_loss = 0.0\n",
    "    kl_loss = 0.0\n",
    "    for text, adv in zip(generated_texts, advantages):\n",
    "        full_text = prompt + text\n",
    "        enc = tokenizer(full_text, return_tensors=\"pt\").to(model.device)\n",
    "        labels = enc.input_ids.clone()\n",
    "        labels[:, :input_ids.shape[1]] = -100  # Mask prompt tokens from the loss\n",
    "        out = model(**enc, labels=labels)\n",
    "        num_gen_tokens = labels[:, input_ids.shape[1]:].numel()\n",
    "        # Advantage-weighted sequence NLL (out.loss is the mean over unmasked tokens)\n",
    "        policy_loss += adv * (out.loss * num_gen_tokens)\n",
    "        # Frozen reference model forward pass for the KL estimate\n",
    "        with torch.no_grad():\n",
    "            ref_out = ref_model(**enc, labels=labels)\n",
    "        # Mean per-token NLL gap vs. the reference policy; gradients flow through\n",
    "        # out.loss, so the penalty actually constrains the update\n",
    "        kl_loss += out.loss - ref_out.loss\n",
    "    policy_loss = policy_loss / group_size\n",
    "    kl_loss = kl_loss / group_size\n",
    "\n",
    "    total_loss = policy_loss + beta * kl_loss\n",
    "    # Scale by the number of accumulation steps so the accumulated gradient is an average\n",
    "    (total_loss / grad_accum_steps).backward()\n",
    "\n",
    "    # Step the optimizer every grad_accum_steps iterations (count iterations, not dataset indices)\n",
    "    if (i + 1) % grad_accum_steps == 0:\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        step += 1\n",
    "        print(f\"Step {step}: policy_loss={policy_loss.item():.4f}, kl_loss={kl_loss.item():.4f}, rewards={rewards}\")\n",
    "        if step >= max_train_steps:\n",
    "            break\n",
    "\n",
    "model.eval()\n",
    "print(\"Training demo completed.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e36ea3a",
   "metadata": {},
   "source": [
    "## 4. Evaluation & Performance Metrics\n",
    "\n",
    "After fine-tuning, we evaluate the model on reasoning benchmarks (e.g., AIME24, GPQA, MATH-500). In this demo, we show an evaluation example for one benchmark.\n",
    "\n",
    "The process involves:\n",
    "\n",
    "- Formatting the prompt exactly as during training.\n",
    "- Generating an answer using greedy decoding.\n",
    "- Extracting the answer using the `<answer>` tags and comparing it with the ground truth (a small helper for this step is sketched below, followed by the benchmark example)."
   ]
  },
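  {
   "cell_type": "markdown",
   "id": "a1b2c3d8",
   "metadata": {},
   "source": [
    "As an optional aside, the helper below isolates the answer-extraction step described above so it can be tested on its own. The string it parses here is a made-up example, not a model output; the benchmark cell that follows keeps its original inline extraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_answer(text):\n",
    "    \"\"\"Return the content of the first <answer>...</answer> block, or fall back to the last token.\"\"\"\n",
    "    if \"<answer>\" in text and \"</answer>\" in text:\n",
    "        return text.split(\"<answer>\")[1].split(\"</answer>\")[0].strip()\n",
    "    tokens = text.strip().split()\n",
    "    return tokens[-1] if tokens else None\n",
    "\n",
    "# Made-up example output\n",
    "print(extract_answer(\"<think>2 + 2 = 4</think><answer>4</answer>\"))  # -> 4"
   ]
  },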
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97d86ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example evaluation for a benchmark (e.g., AIME24)\n",
    "# For illustration, assume we have lists of questions and true answers\n",
    "\n",
    "aime_questions = [\n",
    "    \"If x + y = 10 and x - y = 2, what is the value of x?\",\n",
    "    \"Compute the area of a circle with radius 7.\"\n",
    "]\n",
    "aime_answers = [\n",
    "    \"6\",        # x = 6\n",
    "    \"153.938\"   # 49*pi is approx. 153.938; exact string matching is used, so rounding matters\n",
    "]\n",
    "\n",
    "model.eval()\n",
    "correct = 0\n",
    "for question, true_answer in zip(aime_questions, aime_answers):\n",
    "    prompt = format_prompt(question)\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    output_ids = model.generate(**inputs, max_new_tokens=512, do_sample=False)  # Greedy decoding\n",
    "    output_text = tokenizer.decode(output_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
    "\n",
    "    if \"<answer>\" in output_text and \"</answer>\" in output_text:\n",
    "        ans = output_text.split(\"<answer>\")[1].split(\"</answer>\")[0].strip()\n",
    "    else:\n",
    "        ans = output_text.strip().split()[-1]\n",
    "\n",
    "    print(f\"Question: {question}\")\n",
    "    print(f\"Predicted Answer: {ans}\")\n",
    "    print(f\"True Answer: {true_answer}\\n\")\n",
    "\n",
    "    if str(ans).strip().strip('.') == str(true_answer).strip().strip('.'):\n",
    "        correct += 1\n",
    "\n",
    "accuracy = correct / len(aime_questions) * 100\n",
    "print(f\"AIME24 Accuracy: {accuracy:.1f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7e82731",
   "metadata": {},
   "source": [
    "## 5. Hyperparameter Ablations & Future Directions\n",
    "\n",
    "### Hyperparameter Ablations\n",
    "\n",
    "Key hyperparameters that can be tuned include:\n",
    "\n",
    "- **Learning Rate:** Our default is `1e-6`; values such as `2e-6`, `4e-6`, or `8e-6` are worth trying.\n",
    "- **Group Size:** The number of outputs per prompt (default 7). Increasing it (e.g., 14, 28, or 56) gives a more robust reward baseline at higher computational cost.\n",
    "- **KL Weight (β):** Default is `0.04`. Lower values (e.g., 0.01 or 0.001) give the model more freedom to explore but risk drifting too far from the reference policy.\n",
    "\n",
    "### Future Directions\n",
    "\n",
    "- **Refining the Reward Function:** Improve extraction of the final answer and consider partial rewards for nearly correct outputs.\n",
    "- **Adaptive KL Penalty:** Adjust β during training based on the observed KL divergence (a sketch follows below).\n",
    "- **Scaling Up:** Experiment with larger models or longer generation budgets to fully exploit the model's reasoning capabilities.\n",
    "- **Distillation vs. Pretrained Models:** Compare training outcomes when starting from a distilled model versus a base pretrained model.\n",
    "\n",
    "This concludes our step-by-step guide. Happy fine-tuning!"
   ]
  },
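  {
   "cell_type": "markdown",
   "id": "a1b2c3da",
   "metadata": {},
   "source": [
    "The cell below is a minimal, self-contained sketch of the adaptive KL idea mentioned above: it adjusts β after each step so that the observed KL moves toward a target value. The target KL, the update rate, and the per-step KL values are illustrative assumptions, not quantities taken from the training loop above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of an adaptive KL penalty. The target_kl value and the\n",
    "# proportional update rule are illustrative assumptions.\n",
    "def update_beta(beta, observed_kl, target_kl=0.05, rate=0.1):\n",
    "    \"\"\"Increase beta when the observed KL exceeds the target, decrease it otherwise.\"\"\"\n",
    "    error = (observed_kl - target_kl) / target_kl\n",
    "    new_beta = beta * (1.0 + rate * max(min(error, 1.0), -1.0))  # clamp the relative change\n",
    "    return max(new_beta, 1e-4)  # keep a small floor so the penalty never vanishes\n",
    "\n",
    "# Hypothetical per-step KL values, just to show the behaviour\n",
    "beta_demo = 0.04\n",
    "for kl in [0.02, 0.04, 0.08, 0.12, 0.06]:\n",
    "    beta_demo = update_beta(beta_demo, kl)\n",
    "    print(f\"observed KL={kl:.2f} -> beta={beta_demo:.4f}\")"
   ]
  }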
 ],
 "metadata": {
  "colab": {
   "name": "GRPO_FineTuning_DeepSeek-R1-Distill-Qwen-1.5B.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.x"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
} |