NickyDark1 / grpo_demo.py
Created February 7, 2025 14:24 — forked from cgpeter96/grpo_demo.py
A GRPO modification for multi-GPU training with DeepSpeed, adapted from https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
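GRPO (Group Relative Policy Optimization) samples several completions per prompt and scores each one relative to the rest of its group, so no separate value model is needed. The script that follows delegates this to TRL's `GRPOTrainer`. The sketch below is not code from this gist or from TRL. It is a minimal illustration of the group-normalized advantage, assuming a flat reward list laid out prompt by prompt, a sample standard deviation, and an `eps` of `1e-4`. These are assumptions, and a given TRL version may compute the statistics differently. The function name `group_relative_advantages` is hypothetical.

```python
from statistics import mean, stdev
from typing import List


def group_relative_advantages(rewards: List[float], group_size: int, eps: float = 1e-4) -> List[float]:
    """Hypothetical helper: normalize each reward against its own prompt's group.

    Assumes `rewards` is laid out prompt by prompt: the first `group_size`
    entries belong to prompt 0, the next `group_size` to prompt 1, and so on.
    """
    if group_size < 2:
        raise ValueError("group_size must be at least 2 to compute a group standard deviation")
    if len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a multiple of group_size")

    advantages: List[float] = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu = mean(group)
        sigma = stdev(group)  # sample standard deviation (n - 1 denominator), an assumption
        # eps (assumed value) keeps the division finite when every completion scored the same
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


if __name__ == "__main__":
    # Two prompts, four sampled completions each.
    adv = group_relative_advantages([1.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0, 2.0], group_size=4)
    print(adv)
```

In the usage example, the first group's advantages are roughly +0.87 or -0.87 and sum to zero. The second group, where every completion received the same reward, gets all-zero advantages, so it contributes no preference signal.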
# train_grpo.py
from typing import *
import re
import torch
from datasets import load_dataset, Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer, TrlParser
from dataclasses import dataclass, field