@lewtun · Last active January 29, 2025 10:51
GRPO with vLLM demo
"""Usage (on 8 x H100s):

pip install vllm==0.7.0 --extra-index-url https://download.pytorch.org/whl/cu121
pip install -e '.[dev]'

# DDP
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml --num_processes 7 scratch/grpo_demo.py

# ZeRO-2
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --num_processes 7 scratch/grpo_demo.py

# ZeRO-3
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --num_processes 7 scratch/grpo_demo.py

# FSDP
accelerate launch --config_file examples/accelerate_configs/fsdp.yaml --num_processes 7 scratch/grpo_demo.py

Note: only 7 of the 8 GPUs run training processes; with use_vllm=True the
remaining GPU is used by vLLM for generation (selected via vllm_device="auto").
"""
import random

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def random_reward(completions, **kwargs):
    """Toy reward function: assign a uniformly random score to each completion."""
    return [random.random() for _ in completions]

def main():
    # Load a small slice of the prompt dataset for a quick demo run
    dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:5%]")

    training_args = GRPOConfig(
        output_dir="Qwen2-0.5B-GRPO",
        logging_steps=2,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        max_prompt_length=64,
        max_completion_length=32,
        num_generations=4,
        num_train_epochs=1,
        use_vllm=True,
        vllm_device="auto",  # let TRL pick the spare GPU for generation
        vllm_gpu_memory_utilization=0.7,
        bf16=True,
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs=random_reward,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()
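
The random reward above only exercises the training loop. To plug in a real signal, a reward function just needs to accept the list of completions (plus any extra dataset columns via **kwargs) and return one float per completion. Below is a minimal, hypothetical sketch of such a function that scores completions by how close they are to a target length; the target_len value and the handling of conversational-format completions are assumptions, not part of the original gist.

def length_reward(completions, **kwargs):
    """Hypothetical reward: prefer completions near a target character length."""
    target_len = 100  # assumed target length, tune for your task
    rewards = []
    for completion in completions:
        # With a conversational dataset each completion is a list of message
        # dicts; otherwise it is a plain string (assumption about the format).
        text = completion if isinstance(completion, str) else completion[0]["content"]
        rewards.append(1.0 - min(abs(len(text) - target_len) / target_len, 1.0))
    return rewards

# e.g. pass reward_funcs=length_reward (or a list of reward functions) to GRPOTrainer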