{
"cells": [
{
"cell_type": "markdown",
"id": "3b71edb4",
"metadata": {},
"source": [
"# Applying GRPO to DeepSeek-R1-Distill-Qwen-1.5B with LIMO Dataset\n",
"\n",
"This notebook provides a step-by-step tutorial for applying **Generalized Reinforcement Policy Optimization (GRPO)** to the distilled model **DeepSeek-R1-Distill-Qwen-1.5B** using the high-quality LIMO dataset. We will cover:\n",
"\n",
"1. **Setup & Installation** – Installing dependencies and verifying GPU availability.\n",
"2. **Model & Dataset Preparation** – Loading the model, tokenizer, and dataset, and formatting prompts.\n",
"3. **Reinforcement Learning Fine-Tuning (GRPO)** – Implementing a simplified GRPO training loop, including reward computation and KL regularization.\n",
"4. **Evaluation & Performance Metrics** – Demonstrating how to evaluate the fine-tuned model on benchmark tasks.\n",
"5. **Hyperparameter Ablations & Future Directions** – Discussion on tuning and potential improvements.\n",
"\n",
"Let's begin!"
]
},
{
"cell_type": "markdown",
"id": "b55f4e70",
"metadata": {},
"source": [
"## 1. Setup & Installation\n",
"\n",
"We first install the necessary libraries including PyTorch, Hugging Face Transformers, TRL (for reinforcement learning), the Datasets library, and bitsandbytes for 8-bit optimization. Then, we verify that a GPU is available."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f4e7711",
"metadata": {},
"outputs": [],
"source": [
"!pip install transformers==4.48.2 trl==0.15.0.dev0 datasets bitsandbytes accelerate"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6b72a2c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"print(\"Torch version:\", torch.__version__)\n",
"if torch.cuda.is_available():\n",
" device_name = torch.cuda.get_device_name(0)\n",
" print(\"GPU detected:\", device_name)\n",
" # Enable TF32 for faster matrix multiplication on supported GPUs\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
"else:\n",
" print(\"No GPU found. Please enable a GPU runtime for training.\")"
]
},
{
"cell_type": "markdown",
"id": "eb314410",
"metadata": {},
"source": [
"## 2. Model & Dataset Preparation\n",
"\n",
"We now load the **DeepSeek-R1-Distill-Qwen-1.5B** model and its tokenizer from Hugging Face, and load the LIMO dataset. The dataset consists of high-quality reasoning samples with a `question`, a detailed `solution`, and the final `answer`.\n",
"\n",
"We also define a helper function `format_prompt` that formats the question into a prompt instructing the model to output a reasoning chain and final answer using the tags `<think>` and `<answer>`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3f3f80e",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"model_name = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name, \n",
" torch_dtype=torch.float16, \n",
" device_map=\"auto\"\n",
" # Uncomment the following line if the model requires custom code\n",
" # trust_remote_code=True\n",
")\n",
"\n",
"# Quick test generation\n",
"prompt_test = \"What is the capital of France?\"\n",
"inputs_test = tokenizer(prompt_test, return_tensors=\"pt\").to(model.device)\n",
"outputs_test = model.generate(**inputs_test, max_new_tokens=10)\n",
"print(\"Test output:\", tokenizer.decode(outputs_test[0], skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b8f4d2e",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"# Load the LIMO dataset\n",
"dataset = load_dataset(\"GAIR/LIMO\")\n",
"train_data = dataset[\"train\"]\n",
"print(\"Total training samples:\", len(train_data))\n",
"\n",
"# Display a sample\n",
"sample = train_data[0]\n",
"print(\"Question:\", sample[\"question\"])\n",
"print(\"Solution (excerpt):\", sample[\"solution\"][:100] + \"...\")\n",
"print(\"Answer:\", sample[\"answer\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eaeb2a11",
"metadata": {},
"outputs": [],
"source": [
"def format_prompt(question):\n",
" \"\"\"\n",
" Format the prompt to instruct the model to output a chain-of-thought and final answer.\n",
" \"\"\"\n",
" instruction = (\n",
" \"Solve the following problem step by step, then give the final answer. \"\n",
" \"Format your response as: <think>[reasoning]</think><answer>[final answer]</answer>.\"\n",
" )\n",
" return f\"{instruction}\\nQuestion: {question}\\nSolution:\"\n",
"\n",
"# Test the formatting\n",
"formatted_prompt = format_prompt(sample[\"question\"])\n",
"print(formatted_prompt)"
]
},
{
"cell_type": "markdown",
"id": "f0fce501",
"metadata": {},
"source": [
"## 3. Reinforcement Learning Fine-Tuning (GRPO)\n",
"\n",
"In this section, we implement a simplified GRPO training loop. The main steps include:\n",
"\n",
"- **Sampling:** For each prompt, we generate multiple outputs (a group) from the model.\n",
"- **Reward Scoring:** Compute a reward for each output based on answer accuracy and proper formatting.\n",
"- **Advantage Calculation:** Compute the advantage by comparing each reward to the group average.\n",
"- **Policy Optimization:** Update the model weights using the advantage-weighted log-likelihood loss along with a KL divergence penalty to keep the model close to the reference (base) policy.\n",
"\n",
"We use a default learning rate of `1e-6`, group size of 7, and a KL weight `β = 0.04`. We also set up an optimizer that supports 8-bit parameters (via bitsandbytes) for memory efficiency."
]
},
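{
"cell_type": "markdown",
"id": "grpo-objective-note",
"metadata": {},
"source": [
"**A note on the objective (sketch).** For each prompt $q$ we sample a group of $G$ outputs $o_1, \\dots, o_G$ and score them with rewards $r_1, \\dots, r_G$. The simplified loop in this section uses the mean-only baseline\n",
"\n",
"$$\\hat{A}_i = r_i - \\frac{1}{G}\\sum_{j=1}^{G} r_j,$$\n",
"\n",
"and (approximately) minimizes the advantage-weighted negative log-likelihood of each sampled output plus a KL penalty toward the frozen reference model:\n",
"\n",
"$$\\mathcal{L}(\\theta) \\approx -\\frac{1}{G}\\sum_{i=1}^{G} \\hat{A}_i \\log \\pi_\\theta(o_i \\mid q) + \\beta\\,\\mathrm{KL}\\big(\\pi_\\theta \\,\\|\\, \\pi_{\\mathrm{ref}}\\big).$$\n",
"\n",
"The full GRPO formulation additionally normalizes each advantage by the group's reward standard deviation and uses a PPO-style clipped importance ratio; those refinements are omitted in this simplified demo."
]
},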
{
"cell_type": "code",
"execution_count": null,
"id": "d0b7a87d",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"from transformers import AdamW # Standard AdamW\n",
"\n",
"# Hyperparameters\n",
"learning_rate = 1e-6\n",
"tokens_per_generation = 4096 # Maximum tokens per generation (can be ablated)\n",
"group_size = 7\n",
"beta = 0.04\n",
"\n",
"# Initialize the 8-bit AdamW optimizer (using bitsandbytes)\n",
"import bitsandbytes as bnb\n",
"optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)\n",
"\n",
"# Optionally, use standard 32-bit AdamW:\n",
"# optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
"\n",
"# Clone the initial model to serve as the reference for KL divergence\n",
"from transformers import AutoModelForCausalLM\n",
"ref_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(model.device)\n",
"ref_model.eval()\n",
"for param in ref_model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"def reward_function(question, generated_text, true_answer):\n",
" \"\"\"\n",
" A simple rule-based reward:\n",
" - +0.1 bonus if output contains both <think> and <answer> tags\n",
" - +1.0 if the extracted answer matches the true answer\n",
" - Small penalty if no answer is extracted\n",
" \"\"\"\n",
" answer = None\n",
" if \"<answer>\" in generated_text and \"</answer>\" in generated_text:\n",
" start = generated_text.index(\"<answer>\") + len(\"<answer>\")\n",
" end = generated_text.index(\"</answer>\")\n",
" answer = generated_text[start:end].strip()\n",
" else:\n",
" # Fallback: take the last token as the answer\n",
" answer = generated_text.strip().split()[-1]\n",
"\n",
" reward = 0.0\n",
" # Bonus for proper formatting\n",
" if \"<think>\" in generated_text and \"</think>\" in generated_text and \"<answer>\" in generated_text and \"</answer>\" in generated_text:\n",
" reward += 0.1\n",
" \n",
" # Reward based on answer accuracy\n",
" if answer is not None:\n",
" pred_ans = answer.strip().strip('.')\n",
" true_ans = str(true_answer).strip().strip('.')\n",
" if pred_ans == true_ans:\n",
" reward += 1.0\n",
" else:\n",
" reward -= 0.1\n",
" \n",
" return reward\n",
"\n",
"print(\"Optimizer and reward function set up.\")"
]
},
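{
"cell_type": "markdown",
"id": "reward-sanity-note",
"metadata": {},
"source": [
"Before running the training loop, it is worth sanity-checking `reward_function` on a few hand-written outputs. The strings below are made up purely for illustration; the expected rewards follow directly from the rules defined above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "reward-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check of the reward function on made-up example outputs.\n",
"examples = [\n",
"    (\"<think>2 + 2 = 4</think><answer>4</answer>\", \"4\"),  # well formatted, correct -> 1.1\n",
"    (\"<think>2 + 2 = 5</think><answer>5</answer>\", \"4\"),  # well formatted, wrong   -> 0.0\n",
"    (\"The answer is 4\", \"4\"),                              # no tags, correct -> 1.0\n",
"    (\"I am not sure\", \"4\"),                                # no tags, wrong   -> -0.1\n",
"]\n",
"for text, true_ans in examples:\n",
"    print(f\"reward={reward_function('dummy question', text, true_ans):+.1f}  |  {text!r}\")"
]
},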
{
"cell_type": "code",
"execution_count": null,
"id": "3c51b0ea",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"model.train()\n",
"max_train_steps = 2 # Demo steps; in practice, use many more steps\n",
"grad_accum_steps = 8 # Effective batch: grad_accum_steps * group_size\n",
"\n",
"# Shuffle training indices\n",
"indices = list(range(len(train_data)))\n",
"random.shuffle(indices)\n",
"\n",
"step = 0\n",
"optimizer.zero_grad()\n",
"\n",
"for idx in indices[: max_train_steps * grad_accum_steps]:\n",
" question = train_data[idx][\"question\"]\n",
" true_answer = train_data[idx][\"answer\"]\n",
" prompt = format_prompt(question)\n",
" input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(model.device)\n",
" \n",
" # Generate a group of outputs\n",
" generated_texts = []\n",
" for _ in range(group_size):\n",
" output_ids = model.generate(\n",
" input_ids, \n",
" max_new_tokens=200, # For demo; in practice, use tokens_per_generation\n",
" do_sample=True, \n",
" temperature=1.0,\n",
" eos_token_id=tokenizer.convert_tokens_to_ids(\"</answer>\")\n",
" )\n",
" generated = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
" generated_texts.append(generated)\n",
" \n",
" # Compute rewards and advantages\n",
" rewards = [reward_function(question, text, true_answer) for text in generated_texts]\n",
" baseline = sum(rewards) / len(rewards)\n",
" advantages = [r - baseline for r in rewards]\n",
" \n",
" # Compute policy loss\n",
" policy_loss = 0.0\n",
" for text, adv in zip(generated_texts, advantages):\n",
" full_text = prompt + text\n",
" enc = tokenizer(full_text, return_tensors=\"pt\").to(model.device)\n",
" labels = enc.input_ids.clone()\n",
" labels[:, :input_ids.shape[1]] = -100 # Mask prompt tokens from loss\n",
" out = model(**enc, labels=labels)\n",
" # Multiply the average loss by the number of output tokens\n",
" policy_loss += adv * (out.loss * labels[:, input_ids.shape[1]:].numel())\n",
" policy_loss = policy_loss / group_size\n",
" \n",
" # Approximate KL divergence loss\n",
" kl_loss = 0.0\n",
" for text in generated_texts:\n",
" full_text = prompt + text\n",
" enc = tokenizer(full_text, return_tensors=\"pt\").to(model.device)\n",
" labels = enc.input_ids.clone()\n",
" labels[:, :input_ids.shape[1]] = -100\n",
" with torch.no_grad():\n",
" curr_out = model(**enc, labels=labels)\n",
" ref_out = ref_model(**enc, labels=labels)\n",
" curr_nll = curr_out.loss * labels[:, input_ids.shape[1]:].numel()\n",
" ref_nll = ref_out.loss * labels[:, input_ids.shape[1]:].numel()\n",
" kl_loss += (curr_nll - ref_nll) / labels[:, input_ids.shape[1]:].numel()\n",
" kl_loss = kl_loss / group_size\n",
" \n",
" total_loss = policy_loss + beta * kl_loss\n",
" total_loss.backward()\n",
" \n",
" if (idx + 1) % grad_accum_steps == 0:\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
" step += 1\n",
" print(f\"Step {step}: policy_loss={policy_loss.item():.4f}, kl_loss={kl_loss.item():.4f}, rewards={rewards}\")\n",
" if step >= max_train_steps:\n",
" break\n",
"\n",
"model.eval()\n",
"print(\"Training demo completed.\")"
]
},
{
"cell_type": "markdown",
"id": "7e36ea3a",
"metadata": {},
"source": [
"## 4. Evaluation & Performance Metrics\n",
"\n",
"After fine-tuning, we evaluate the model on reasoning benchmarks (e.g., AIME24, GPQA, MATH-500). In this demo, we show an evaluation example for one benchmark. \n",
"\n",
"The process involves:\n",
"\n",
"- Formatting the prompt as during training.\n",
"- Generating an answer using greedy decoding.\n",
"- Extracting the answer using the `<answer>` tags and comparing it with the ground truth."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97d86ce3",
"metadata": {},
"outputs": [],
"source": [
"# Example evaluation for a benchmark (e.g., AIME24)\n",
"# For illustration, let's assume we have lists of questions and true answers\n",
"\n",
"aime_questions = [\n",
" \"If x + y = 10 and x - y = 2, what is the value of x?\",\n",
" \"Compute the area of a circle with radius 7.\"\n",
"]\n",
"aime_answers = [\n",
" \"6\", # x = 6\n",
" \"153.938\" # Approximate area (could be rounded)\n",
"]\n",
"\n",
"model.eval()\n",
"correct = 0\n",
"for question, true_answer in zip(aime_questions, aime_answers):\n",
" prompt = format_prompt(question)\n",
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
" output_ids = model.generate(**inputs, max_new_tokens=512, temperature=0.0) # Greedy decoding\n",
" output_text = tokenizer.decode(output_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
" \n",
" if \"<answer>\" in output_text and \"</answer>\" in output_text:\n",
" ans = output_text.split(\"<answer>\")[1].split(\"</answer>\")[0].strip()\n",
" else:\n",
" ans = output_text.strip().split()[-1]\n",
" \n",
" print(f\"Question: {question}\")\n",
" print(f\"Predicted Answer: {ans}\")\n",
" print(f\"True Answer: {true_answer}\\n\")\n",
" \n",
" if str(ans).strip().strip('.') == str(true_answer).strip().strip('.'):\n",
" correct += 1\n",
"\n",
"accuracy = correct / len(aime_questions) * 100\n",
"print(f\"AIME24 Accuracy: {accuracy:.1f}%\")"
]
},
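{
"cell_type": "markdown",
"id": "benchmark-eval-note",
"metadata": {},
"source": [
"The toy example above only checks the mechanics. A realistic evaluation loops over an actual benchmark split with the same prompt formatting and answer extraction, reusing the `model`, `tokenizer`, and `format_prompt` defined earlier. In the sketch below, the Hub dataset id and field names in the commented example (`HuggingFaceH4/MATH-500`, `problem`, `answer`) are assumptions for illustration only; substitute the benchmark you actually use."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "benchmark-eval-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: evaluate on a full benchmark split (the dataset id and field names are assumptions).\n",
"from datasets import load_dataset\n",
"\n",
"def extract_answer(output_text):\n",
"    \"\"\"Pull the answer out of <answer>...</answer>, falling back to the last word.\"\"\"\n",
"    if \"<answer>\" in output_text and \"</answer>\" in output_text:\n",
"        return output_text.split(\"<answer>\")[1].split(\"</answer>\")[0].strip()\n",
"    return output_text.strip().split()[-1] if output_text.strip() else \"\"\n",
"\n",
"def evaluate_benchmark(questions, answers, max_new_tokens=512):\n",
"    \"\"\"Greedy-decode each question and report exact-match accuracy in percent.\"\"\"\n",
"    model.eval()\n",
"    correct = 0\n",
"    for question, true_answer in zip(questions, answers):\n",
"        inputs = tokenizer(format_prompt(question), return_tensors=\"pt\").to(model.device)\n",
"        with torch.no_grad():\n",
"            output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)\n",
"        output_text = tokenizer.decode(output_ids[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
"        if extract_answer(output_text).strip().strip('.') == str(true_answer).strip().strip('.'):\n",
"            correct += 1\n",
"    return correct / max(len(questions), 1) * 100\n",
"\n",
"# Example usage (assumed dataset id and fields; uncomment to run):\n",
"# bench = load_dataset(\"HuggingFaceH4/MATH-500\", split=\"test\")\n",
"# print(\"Accuracy:\", evaluate_benchmark(bench[\"problem\"][:20], bench[\"answer\"][:20]))"
]
},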
{
"cell_type": "markdown",
"id": "d7e82731",
"metadata": {},
"source": [
"## 5. Hyperparameter Ablations & Future Directions\n",
"\n",
"### Hyperparameter Ablations\n",
"\n",
"Key hyperparameters that can be tuned include:\n",
"\n",
"- **Learning Rate:** Our default is `1e-6`, but values like `2e-6`, `4e-6`, or `8e-6` may be experimented with.\n",
"- **Group Size:** Number of outputs per prompt (default is 7). Increasing this (e.g., 14, 28, or 56) can provide a more robust reward baseline but at higher computational cost.\n",
"- **KL Weight (β):** Default is `0.04`. Lower values (e.g., 0.01 or 0.001) allow the model more freedom to explore but may risk divergence.\n",
"\n",
"### Future Directions\n",
"\n",
"- **Refining the Reward Function:** Improve extraction of the final answer and consider partial rewards for nearly correct outputs.\n",
"- **Adaptive KL Penalty:** Use adaptive techniques to adjust β based on the observed KL divergence during training.\n",
"- **Scaling Up:** Experiment with larger models or longer generation tokens to fully exploit the reasoning capabilities.\n",
"- **Distillation vs. Pretrained Models:** Compare training outcomes when starting from a distilled model versus a base pretrained model.\n",
"\n",
"This concludes our step-by-step guide. Happy fine-tuning!"
]
}
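,
{
"cell_type": "markdown",
"id": "ablation-grid-note",
"metadata": {},
"source": [
"### Appendix: Minimal Ablation Grid Sketch\n",
"\n",
"As a follow-up to the hyperparameter list above, the cell below sketches a minimal grid search over learning rate, group size, and KL weight. The `run_grpo_training` wrapper is **hypothetical** (a function you would write around the Section 3 training loop, returning a validation metric such as mean reward), so the actual call is left commented out."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ablation-grid-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Sketch of a small ablation grid over the GRPO hyperparameters from Section 3.\n",
"# `run_grpo_training` is a hypothetical wrapper around the training loop above\n",
"# (e.g., returning mean reward on a held-out set); it is not defined in this notebook.\n",
"from itertools import product\n",
"\n",
"learning_rates = [1e-6, 2e-6, 4e-6]\n",
"group_sizes = [7, 14]\n",
"betas = [0.04, 0.01]\n",
"\n",
"results = {}\n",
"for lr, g, b in product(learning_rates, group_sizes, betas):\n",
"    print(f\"Would run GRPO with lr={lr}, group_size={g}, beta={b}\")\n",
"    # results[(lr, g, b)] = run_grpo_training(lr=lr, group_size=g, beta=b)\n",
"\n",
"# Once the runs are filled in, pick the best configuration:\n",
"# best_config = max(results, key=results.get)"
]
}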
],
"metadata": {
"colab": {
"name": "GRPO_FineTuning_DeepSeek-R1-Distill-Qwen-1.5B.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.x"
}
},
"nbformat": 4,
"nbformat_minor": 5
}