Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save abodacs/56a4ce14d9a43ce6074cbcfe4808fcf8 to your computer and use it in GitHub Desktop.
Save abodacs/56a4ce14d9a43ce6074cbcfe4808fcf8 to your computer and use it in GitHub Desktop.
grpo_qwen-0-5b_single_t4.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyOo1dcPaEBK34fiojzvmayR",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/qunash/820c86d1d267ec8051d9f68b4f4bb656/grpo_qwen-0-5b_single_t4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Full **GRPO** fine-tuning `Qwen2.5 0.5B` on a single T4\n",
"This colab uses a lot of tweaks and tricks to make GRPO **full fine-tuning** Qwen2.5-0.5-Instruct fit on a single T4 GPU, so that it could be run in a free Google Colab.\n",
"\n",
"It uses VLLm for fast inference and does not make compromises on batch and completion group sizes.\n",
"\n",
"With this setup you can improve `Qwen2.5-0.5B-Instruct`'s gsm8k eval result from 22.4% to 48.6% in just \\~150 steps (~30 minutes) on a single T4 GPU.\n",
"\n",
"\n",
"</br>\n",
"\n",
"---\n",
"\n",
"Here are some important optimizations used:\n",
"\n",
"* A [fork](https://github.com/andyl98/trl/tree/grpo-vram-optimization) of the TRL repo by [andyl98](https://github.com/andyl98), which introduces batched logprobs calculation. I then forked this fork and further optimized the logprobs computation function to reduce VRAM usage.\n",
"* 8-bit AdamW optimizer\n",
"* Set explicit memory allocation limits with `PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:128'`\n",
"\n",
"</br>\n",
"\n",
"---\n",
"\n",
"If using Ampere, or later architecture nvidia GPU, you can further reduce VRAM usage by:\n",
"\n",
"\n",
"* enabling `attn_implementation=\"flash_attention_2\"` during model loading\n",
"* loading the model with [Liger-Kernel](https://github.com/linkedin/Liger-Kernel) wrapper:\n",
"\n",
" ```Python\n",
" from liger_kernel.transformers import AutoLigerKernelForCausalLM\n",
" model = AutoLigerKernelForCausalLM.from_pretrained(\"path/to/some/model\")\n",
" ```\n",
"\n",
"[![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fgist.github.com%2Fqunash%2F820c86d1d267ec8051d9f68b4f4bb656&label=views&countColor=%23263759)](https://visitorbadge.io/status?path=https%3A%2F%2Fgist.github.com%2Fqunash%2F820c86d1d267ec8051d9f68b4f4bb656)"
],
"metadata": {
"id": "8oW2D1_PpNqF"
}
},
{
"cell_type": "code",
"source": [
"%%capture\n",
"!pip install uv\n",
"!uv pip install --system git+https://github.com/qunash/trl-1.git@grpo-vram-optimization\n",
"!uv pip install --system triton==2.2.0\n",
"!uv pip install --system vllm\n",
"!uv pip install --system bitsandbytes"
],
"metadata": {
"id": "znbQSsMqi7HJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"import re\n",
"import torch\n",
"from datasets import load_dataset, Dataset\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"from trl.trainer import GRPOConfig, GRPOTrainer\n",
"\n",
"\n",
"R1_STYLE_SYSTEM_PROMPT = \"\"\"A conversation between User and Assistant. The user asks a question, and the Assistant solves it.\n",
"The assistant first thinks about the reasoning process in the mind and then provides the user\n",
"with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and\n",
"<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>\n",
"<answer> answer here </answer>.\"\"\"\n",
"\n",
"TASK_SPECIFIC_INSTRUCTIONS = \"The answer must be a single integer.\"\n",
"\n",
"\n",
"def preprocess_dataset(dataset_name, split=\"train\", chunk_size=1000) -> Dataset:\n",
" dataset = load_dataset(dataset_name, 'main')[split]\n",
"\n",
" def extract_hash_answer(text: str) -> str | None:\n",
" try:\n",
" return text.split(\"####\")[1].strip()\n",
" except IndexError:\n",
" return None\n",
"\n",
" def process_batch(batch):\n",
" prompts = [[\n",
" {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + \"\\n\" + TASK_SPECIFIC_INSTRUCTIONS},\n",
" {'role': 'user', 'content': \"What is 2+2?\"},\n",
" {'role': 'assistant', 'content': \"<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\\n<answer>4</answer>\"},\n",
" {'role': 'user', 'content': q.strip()}\n",
" ] for q in batch['question']]\n",
"\n",
" return {\n",
" 'prompt': prompts,\n",
" 'answer': [extract_hash_answer(a) for a in batch['answer']]\n",
" }\n",
"\n",
" return dataset.map(process_batch, batched=True, batch_size=chunk_size)\n",
"\n",
"dataset_name = 'openai/gsm8k'\n",
"dataset = preprocess_dataset(dataset_name, chunk_size=500)\n",
"\n",
"\n",
"def extract_xml_answer(text: str) -> str:\n",
" try:\n",
" answer = text.split(\"<answer>\")[-1].split(\"</answer>\")[0].strip()\n",
" return answer\n",
" except IndexError:\n",
" return \"\"\n",
"\n",
"# reward functions\n",
"# VALID_FORMAT = re.compile(r\"<reasoning>(?:(?!</?reasoning>|</?answer>).)*</reasoning>\\n<answer>(?:(?!</?reasoning>|</?answer>).)*</answer>\")\n",
"\n",
"# def format_reward_func(completions, **kwargs) -> list[float]:\n",
"# \"\"\"Reward function that checks if the completion has the correct format.\"\"\"\n",
"# responses = [completion[0][\"content\"] for completion in completions]\n",
"# matches = [bool(VALID_FORMAT.fullmatch(r.strip())) for r in responses]\n",
"# return [1.0 if match else 0.0 for match in matches]\n",
"\n",
"def format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has the correct format.\"\"\"\n",
" pattern = r\"^<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>$\"\n",
" responses = [completion[0][\"content\"] for completion in completions]\n",
" matches = [bool(re.match(pattern, r)) for r in responses]\n",
" return [1.0 if match else 0.0 for match in matches]\n",
"\n",
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the answer is correct.\"\"\"\n",
" responses = [completion[0]['content'] for completion in completions]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" print(f\"Question: {prompts[0][-1]['content']}\\nAnswer: {answer[0]}\\nResponse: {responses[0]}\\nExtracted: {extracted_responses[0]}\")\n",
" print(''.join('✅' if r == a else '❌' for r, a in zip(extracted_responses, answer)))\n",
" return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
"\n",
"# model_name = \"Qwen/Qwen2.5-0.5B\"\n",
"model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
"\n",
"output_dir = f\"outputs/{model_name.split('/')[-1]}-GRPO\"\n",
"run_name = f\"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}\"\n",
"\n",
"\n",
"# Set memory-related environment variables\n",
"os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'\n",
"\n",
"max_prompt_length=256\n",
"max_completion_length=512\n",
"\n",
"training_args = GRPOConfig(\n",
" output_dir=output_dir,\n",
" run_name=run_name,\n",
" learning_rate=1e-5,\n",
" beta=0.005, # divergence coefficient – how much the policy is allowed to deviate from the reference model. higher value – more conservative updates. Default is 0.04\n",
" optim=\"adamw_8bit\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.99,\n",
" weight_decay=0.1,\n",
" warmup_ratio=0.1,\n",
" lr_scheduler_type='cosine',\n",
" logging_steps=1,\n",
" bf16=True,\n",
" per_device_train_batch_size=4,\n",
" num_generations=4, # group size\n",
" gradient_accumulation_steps=4,\n",
" max_prompt_length=max_prompt_length,\n",
" max_completion_length=max_completion_length,\n",
" num_train_epochs=1,\n",
" save_steps=100,\n",
" max_grad_norm=0.1,\n",
" report_to=\"wandb\",\n",
" log_on_each_node=False,\n",
" use_vllm=True,\n",
" vllm_init_kwargs={\n",
" \"device\": \"cuda:0\",\n",
" \"gpu_memory_utilization\": 0.3,\n",
" \"max_model_len\": max_prompt_length + max_completion_length,\n",
" \"dtype\": \"half\",\n",
" # \"enable_chunked_prefill\": True,\n",
" # \"max_num_batched_tokens\": 2048,\n",
" },\n",
" gradient_checkpointing=True,\n",
" gradient_checkpointing_kwargs={\"use_reentrant\": False},\n",
" logit_computation_mini_batch_size=1,\n",
" enable_profiling=False\n",
")\n",
"\n",
"# Load model\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" torch_dtype=torch.bfloat16,\n",
" # attn_implementation=\"flash_attention_2\", # T4 is not supported\n",
" device_map=\"auto\",\n",
")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
" model_name,\n",
" model_max_length=training_args.max_completion_length,\n",
")\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"# Initialize trainer\n",
"trainer = GRPOTrainer(\n",
" model=model,\n",
" processing_class=tokenizer,\n",
" reward_funcs=[\n",
" format_reward_func,\n",
" correctness_reward_func\n",
" ],\n",
" args=training_args,\n",
" train_dataset=dataset,\n",
")\n",
"\n",
"trainer.train()\n"
],
"metadata": {
"id": "RnACYvTBWA1q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Eval"
],
"metadata": {
"id": "OFRHLvWhzHn0"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer\n",
"from vllm import LLM, SamplingParams\n",
"from tqdm.notebook import tqdm\n",
"import numpy as np\n",
"from typing import List, Dict\n",
"import json\n",
"from datetime import datetime\n",
"import logging\n",
"\n",
"# Disable VLLM's progress bars\n",
"logging.getLogger(\"vllm\").setLevel(logging.WARNING)\n",
"\n",
"# Constants from training script\n",
"R1_STYLE_SYSTEM_PROMPT = \"\"\"A conversation between User and Assistant. The user asks a question, and the Assistant solves it.\n",
"The assistant first thinks about the reasoning process in the mind and then provides the user\n",
"with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and\n",
"<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>\n",
"<answer> answer here </answer>.\"\"\"\n",
"\n",
"TASK_SPECIFIC_INSTRUCTIONS = \"The answer must be a single integer.\"\n",
"\n",
"def extract_xml_answer(text: str) -> str:\n",
" try:\n",
" answer = text.split(\"<answer>\")[-1].split(\"</answer>\")[0].strip()\n",
" return answer\n",
" except IndexError:\n",
" return \"\"\n",
"\n",
"def extract_hash_answer(text: str) -> str | None:\n",
" try:\n",
" return text.split(\"####\")[1].strip()\n",
" except IndexError:\n",
" return None\n",
"\n",
"def evaluate_model(\n",
" model_path: str,\n",
" batch_size: int = 4,\n",
" num_samples: int = None,\n",
" save_results: bool = True,\n",
" gpu_memory_utilization: float = 0.3,\n",
") -> Dict:\n",
" print(\"Initializing evaluation...\")\n",
"\n",
" # Initialize VLLM with progress indicator\n",
" with tqdm(total=2, desc=\"Loading model components\") as pbar:\n",
" llm = LLM(\n",
" model=model_path,\n",
" dtype=\"half\",\n",
" gpu_memory_utilization=gpu_memory_utilization,\n",
" max_model_len=768,\n",
" device=\"cuda:0\",\n",
" enable_chunked_prefill=True,\n",
" )\n",
" pbar.update(1)\n",
"\n",
" tokenizer = AutoTokenizer.from_pretrained(\n",
" model_path,\n",
" model_max_length=768,\n",
" padding_side='right',\n",
" truncation_side='right'\n",
" )\n",
" pbar.update(1)\n",
"\n",
" # Set up sampling parameters\n",
" sampling_params = SamplingParams(\n",
" temperature=0.0,\n",
" max_tokens=512, # Matching max_completion_length from training\n",
" stop_token_ids=[tokenizer.eos_token_id],\n",
" )\n",
"\n",
" # Load test dataset\n",
" print(\"Loading dataset...\")\n",
" dataset = load_dataset('openai/gsm8k', 'main', split='test')\n",
" if num_samples:\n",
" dataset = dataset.select(range(num_samples))\n",
" total_samples = len(dataset)\n",
" print(f\"Loaded {total_samples} samples\")\n",
"\n",
" results = []\n",
" correct = 0\n",
" total = 0\n",
"\n",
" # Create progress bar\n",
" progress_bar = tqdm(\n",
" total=total_samples,\n",
" desc=\"Processing samples\",\n",
" unit=\"examples\",\n",
" dynamic_ncols=True,\n",
" )\n",
"\n",
" progress_bar.set_postfix({\n",
" 'acc': '0.00%',\n",
" 'correct': '0',\n",
" })\n",
"\n",
" # Process in batches\n",
" for i in range(0, total_samples, batch_size):\n",
" batch_data = dataset[i:i + batch_size]\n",
" current_batch_size = len(batch_data['question'])\n",
"\n",
" # Prepare prompts using same format as training\n",
" prompts = [\n",
" [\n",
" {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + \"\\n\" + TASK_SPECIFIC_INSTRUCTIONS},\n",
" {'role': 'user', 'content': \"What is 2+2?\"},\n",
" {'role': 'assistant', 'content': \"<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\\n<answer>4</answer>\"},\n",
" {'role': 'user', 'content': q.strip()}\n",
" ] for q in batch_data['question']\n",
" ]\n",
"\n",
" # Convert to chat format\n",
" formatted_prompts = [\n",
" tokenizer.apply_chat_template(\n",
" p,\n",
" tokenize=False,\n",
" add_generation_prompt=True\n",
" )\n",
" for p in prompts\n",
" ]\n",
"\n",
" # Generate responses\n",
" outputs = llm.generate(\n",
" formatted_prompts,\n",
" sampling_params,\n",
" )\n",
"\n",
" # Process responses\n",
" for j, output in enumerate(outputs):\n",
" response = output.outputs[0].text\n",
"\n",
" # Extract answers\n",
" generated_answer = extract_xml_answer(response)\n",
" true_answer = extract_hash_answer(batch_data['answer'][j])\n",
"\n",
" # Store result\n",
" result = {\n",
" 'question': batch_data['question'][j],\n",
" 'true_answer': true_answer,\n",
" 'generated_answer': generated_answer,\n",
" 'full_response': response,\n",
" 'correct': generated_answer == true_answer\n",
" }\n",
" results.append(result)\n",
"\n",
" # Update metrics\n",
" if generated_answer == true_answer:\n",
" correct += 1\n",
" total += 1\n",
"\n",
" # Update progress\n",
" progress_bar.update(current_batch_size)\n",
" progress_bar.set_postfix({\n",
" 'acc': f'{(correct/total)*100:.2f}%',\n",
" 'correct': f'{correct}/{total}',\n",
" })\n",
"\n",
" progress_bar.close()\n",
"\n",
" # Calculate metrics\n",
" accuracy = correct / total if total > 0 else 0\n",
" metrics = {\n",
" 'accuracy': accuracy,\n",
" 'correct': correct,\n",
" 'total': total,\n",
" 'model_path': model_path,\n",
" 'timestamp': datetime.now().isoformat()\n",
" }\n",
"\n",
" # Save results\n",
" if save_results:\n",
" save_path = f\"gsm8k_eval_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json\"\n",
" with open(save_path, 'w') as f:\n",
" json.dump({\n",
" 'metrics': metrics,\n",
" 'results': results\n",
" }, f, indent=2)\n",
" print(f\"\\nResults saved to {save_path}\")\n",
"\n",
" return metrics\n",
"\n",
"print(\"Starting GSM8K evaluation...\")\n",
"checkpoint_path = \"outputs/Qwen2.5-0.5B-Instruct-GRPO/checkpoint-latest\" # Update path as needed\n",
"\n",
"metrics = evaluate_model(\n",
" model_path=checkpoint_path,\n",
" batch_size=4,\n",
" num_samples=None,\n",
" save_results=True,\n",
" gpu_memory_utilization=0.3,\n",
")\n",
"\n",
"print(\"\\nFinal Evaluation Results:\")\n",
"print(f\"Accuracy: {metrics['accuracy']:.2%}\")\n",
"print(f\"Correct: {metrics['correct']}/{metrics['total']}\")"
],
"metadata": {
"id": "nW6pJMSDD2sv"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment