# Copyright 2024 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------
# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF
# which has the following license:
# Copyright [yyyy] [name of copyright owner]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# isort: off
from collections import defaultdict
import json
import os
import shutil
os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA
# isort: on
import logging
import os
import random
import socket
import threading
import time
import traceback
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from queue import Empty, Queue
from typing import Callable, Iterator, List, Literal, Optional
import deepspeed
import numpy as np
import pandas as pd
import ray
import torch
import torch.utils
import torch.utils.data
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from huggingface_hub import HfApi
from peft import PeftModel, get_peft_model_state_dict
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rich.pretty import pprint
from torch.utils.tensorboard import SummaryWriter
from transformers import (
AutoModelForCausalLM,
PreTrainedModel,
PreTrainedTokenizer,
get_scheduler,
)
from transformers.integrations import HfDeepSpeedConfig
from vllm import SamplingParams
from open_instruct.dataset_transformation import (
DATASET_SOURCE_KEY,
GROUND_TRUTHS_KEY,
INPUT_IDS_PROMPT_KEY,
TokenizerConfig,
get_cached_dataset_tulu,
visualize_token,
)
from open_instruct.ground_truth_utils import soft_format_reward_func
from open_instruct.model_utils import (
ModelConfig,
apply_verifiable_reward,
disable_dropout_in_model,
log_softmax_and_gather,
print_rich_single_line_metrics,
print_rich_table,
push_folder_to_hub,
)
from open_instruct.rl_utils2 import pack_sequences
from open_instruct.utils import (
ArgumentParserPlus,
BeakerRuntimeConfig,
get_wandb_tags,
is_beaker_job,
launch_ai2_evals_on_weka,
maybe_get_beaker_config,
maybe_use_ai2_hf_entity,
maybe_use_ai2_wandb_entity,
)
from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group
api = HfApi()
INVALID_LOGPROB = 1.0
@dataclass
class Args:
# Dataset
dataset_mixer_list: List[str] = field(default_factory=lambda: ["ai2-adapt-dev/rlvr_gsm8k_zs", "1.0"])
"""A list of datasets (local or HF) to sample from."""
dataset_mixer_eval_list: List[str] = field(default_factory=lambda: ["ai2-adapt-dev/rlvr_gsm8k_zs", "1.0"])
"""A list of datasets (local or HF) to sample from for evaluation."""
dataset_mixer_list_splits: List[str] = field(default_factory=lambda: ["train"])
"""The dataset splits to use for training"""
dataset_mixer_eval_list_splits: List[str] = field(default_factory=lambda: ["test"])
"""The dataset splits to use for evaluation"""
dataset_transform_fn: list[str] = field(default_factory=lambda: ["rlvr_tokenize_v1", "rlvr_filter_v1"])
"""The list of transform functions to apply to the dataset."""
dataset_cache_mode: Literal["hf", "local"] = "local"
"""The mode to use for caching the dataset."""
dataset_local_cache_dir: str = "local_dataset_cache"
"""The directory to save the local dataset cache to."""
dataset_config_hash: Optional[str] = None
"""The hash of the dataset configuration."""
dataset_config_eval_hash: Optional[str] = None
"""The hash of the dataset configuration for evaluation."""
dataset_skip_cache: bool = False
"""Whether to skip the cache."""
max_token_length: int = 512
"""The maximum token length to use for the dataset"""
max_prompt_token_length: int = 256
"""The maximum prompt token length to use for the dataset"""
# Experiment
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""The name of this experiment"""
seed: int = 1
"""Seed of the experiment"""
run_name: Optional[str] = None
"""RUNTIME VALUE: A unique name of this run"""
# Optimizer
learning_rate: float = 2e-5
"""The initial learning rate for AdamW optimizer."""
lr_scheduler_type: Literal[
"linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
] = "linear"
"""Which scheduler to use"""
warm_up_steps: int = 0
"""Number of warm up steps for the scheduler"""
warmup_ratio: float = 0.0
"""Ratio of warmup steps to total steps (takes precedence over `warm_up_steps`)"""
weight_decay: float = 0.0
"""Weight decay for AdamW if we apply some."""
set_weight_decay_on_bias_and_norm: bool = True
"""Whether to set weight decay on bias and norm layers"""
fused_optimizer: bool = False
"""Whether to use fused optimizer"""
# Batch sizes
per_device_train_batch_size: int = 1
"""The forward batch size per device (local_micro_batch_size)"""
total_episodes: int = 100000
"""The total number of episodes in the dataset"""
world_size: Optional[int] = None
"""RUNTIME VALUE: The number of processes (GPUs) to use"""
num_training_steps: Optional[int] = None
"""RUNTIME VALUE: The number of training_steps to train"""
num_evals: int = 10
"""The number of evaluations to run throughout training"""
eval_freq: Optional[int] = None
"""RUNTIME VALUE: The frequency of evaluation steps"""
save_freq: int = -1
"""How many train steps to save the model"""
# Generation
response_length: int = 256
"""the length of the response"""
temperature: float = 0.7
"""the sampling temperature"""
num_unique_prompts_rollout: int = 16
"""The number of unique prompts during rollout"""
num_samples_per_prompt_rollout: int = 4
"""the number of samples to generate per prompt during rollout, useful for easy-star"""
stop_strings: Optional[List[str]] = None
"""List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings."""
# Algorithm
async_mode: bool = True
"""Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)"""
num_epochs: int = 1
"""the number of epochs to train"""
num_mini_batches: int = 1
"""Number of minibatches to split a batch into"""
beta: float = 0.05
"""the beta value of the RLHF objective (KL coefficient)"""
cliprange: float = 0.2
"""the clip range"""
kl_estimator: Literal["kl1", "kl2", "kl3", "kl4"] = "kl3"
"""the KL estimator to use"""
pack_length: int = 512
"""the length of the pack (you should prob set to the max length of the model)"""
dr_grpo: bool = False
"""whether to use the DR-GRPO objective (https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf)"""
masked_mean_axis: Optional[int] = None
"""the axis to compute the mean of the masked values"""
# Reward
# -- r1 style format reward
apply_r1_style_format_reward: bool = False
"""whether to add the R1 style format reward"""
r1_style_format_reward: float = 1.0
"""the reward value for R1 style format reward"""
# -- verifiable reward
apply_verifiable_reward: bool = True
"""whether to apply verifiable reward"""
verification_reward: float = 10.0
"""the reward value for verifiable responses"""
# -- non stop penalty
non_stop_penalty: bool = False
"""whether to penalize responses which did not finish generation"""
non_stop_penalty_value: float = 0.0
"""the reward value for responses which did not finish generation"""
# -- arithmetic reward
apply_arithmetic_reward: bool = False
"""whether to apply arithmetic reward"""
arithmetic_reward: float = 10.0
"""the reward value for arithmetic responses"""
# Ray
single_gpu_mode: bool = False
"""whether to collocate vLLM and actor on the same node (mostly for debugging purposes)"""
num_learners_per_node: List[int] = field(default_factory=lambda: [1])
"""number of GPU deepspeed learners per node (e.g., --num_learners_per_node 2 4 means 2 learner processes
on the first node and 4 learner processes on the second node; each process will have 1 GPU)"""
vllm_num_engines: int = 1
"""number of vLLM Engines, set to 0 to disable vLLM"""
vllm_tensor_parallel_size: int = 1
"""tensor parallel size of vLLM Engine for multi-GPU inference"""
vllm_enforce_eager: bool = False
"""whether to enforce eager mode for vLLM -- slow inference but needed for multi-node"""
vllm_sync_backend: str = "nccl"
"""DeepSpeed -> vLLM weight sync backend"""
vllm_gpu_memory_utilization: float = 0.9
"""vLLM GPU memory utilization"""
vllm_enable_prefix_caching: bool = False
"""whether to enable prefix caching"""
deepspeed_stage: int = 0
"""the deepspeed stage"""
gather_whole_model: bool = True
"""whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)"""
# Experiment tracking
with_tracking: bool = False
"""If toggled, this experiment will be tracked with Weights and Biases"""
wandb_project_name: str = "open_instruct_internal"
"""The wandb's project name"""
wandb_entity: Optional[str] = None
"""The entity (team) of wandb's project"""
push_to_hub: bool = True
"""Whether to upload the saved model to huggingface"""
hf_entity: Optional[str] = None
"""The user or org name of the model repository from the Hugging Face Hub"""
hf_repo_id: Optional[str] = None
"""The id of the saved model in the Hugging Face Hub (can be autoset if not given)"""
hf_repo_revision: Optional[str] = None
"""The revision of the saved model in the Hugging Face Hub (can be autoset if not given)"""
hf_repo_url: Optional[str] = None
"""The url of the saved model in the Hugging Face Hub (will be autoset)"""
output_dir: str = "output"
"""Where to save the model"""
save_traces: bool = False
"""Whether to save learning data traces"""
cache_dataset_only: bool = False
"""Immediately exit after caching the dataset"""
# Ai2 specific settings
try_launch_beaker_eval_jobs_on_weka: bool = False
"""Whether to launch beaker evaluation jobs after training on weka"""
try_auto_save_to_beaker: bool = True
"""Whether to try to save the model to Beaker dataset `/output` after training"""
gs_bucket_path: Optional[str] = None
"""The path to the gs bucket to save the model to"""
oe_eval_tasks: Optional[List[str]] = None
"""The beaker evaluation tasks to launch"""
oe_eval_max_length: int = 4096
"""the max generation length for evaluation for oe-eval"""
eval_priority: Literal["low", "normal", "high", "urgent"] = "normal"
"""the priority of auto-launched evaluation jobs"""
def __post_init__(self):
if self.single_gpu_mode:
self.vllm_gpu_memory_utilization = 0.3
assert self.num_samples_per_prompt_rollout > 1, "Number of samples per prompt must be greater than 1 for GRPO!"
assert (
self.apply_verifiable_reward or self.apply_r1_style_format_reward or self.non_stop_penalty
), "At least one reward must be applied!"
assert (
self.pack_length >= self.max_prompt_token_length + self.response_length
), "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!"
def get_train_ds_config(
offload,
adam_offload=False,
stage=0,
bf16=True,
max_norm=1.0,
zpg=8,
grad_accum_dtype=None,
disable_trace_cache=True,
):
device = "cpu" if offload else "none"
zero_opt_dict = {
"stage": stage,
"offload_param": {"device": device},
"offload_optimizer": {
"device": "cpu" if adam_offload else "none",
"pin_memory": True,
},
"sub_group_size": "auto",
"stage3_max_live_parameters": "auto",
"stage3_max_reuse_distance": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_prefetch_bucket_size": "auto",
"reduce_bucket_size": "auto",
# # ZeRO++
# "zero_hpz_partition_size": zpg,
# "zero_quantized_weights": False,
# "zero_quantized_gradients": False,
}
if disable_trace_cache:
zero_opt_dict["stage3_prefetch_bucket_size"] = 0
zero_opt_dict["stage3_max_live_parameters"] = 0
zero_opt_dict["stage3_max_reuse_distance"] = 0
return {
"steps_per_print": 100,
"zero_optimization": zero_opt_dict,
"bf16": {
"enabled": bf16,
},
"gradient_clipping": max_norm,
"prescale_gradients": False,
"wall_clock_breakdown": False,
"data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"},
}
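# Minimal sketch of how get_train_ds_config is consumed (illustrative only; the "auto"
# entries are resolved by DeepSpeed at initialization time):
#
#   cfg = get_train_ds_config(offload=False, adam_offload=False, stage=3, bf16=True)
#   # cfg["zero_optimization"]["stage"] == 3
#   # cfg["bf16"]["enabled"] is True
#   # cfg["data_types"]["grad_accum_dtype"] == "fp32"
#   cfg["train_micro_batch_size_per_gpu"] = 1   # the trainer fills these two in below
#   cfg["gradient_accumulation_steps"] = 1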
def get_eval_ds_config(
offload,
stage=0,
bf16=True,
):
zero_opt_dict = {
"stage": stage,
"stage3_param_persistence_threshold": "auto",
"offload_param": {
"device": "cpu" if offload else "none",
"pin_memory": True,
},
}
return {
"steps_per_print": 100,
"zero_optimization": zero_opt_dict,
"bf16": {
"enabled": bf16,
},
"prescale_gradients": False,
"wall_clock_breakdown": False,
}
def get_optimizer_grouped_parameters(
model: torch.nn.Module,
weight_decay: float,
no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"],
):
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in model.named_parameters()
if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad)
],
"weight_decay": weight_decay,
},
{
"params": [
p
for n, p in model.named_parameters()
if (any(nd in n for nd in no_decay_name_list) and p.requires_grad)
],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters
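# Hedged usage sketch: the two groups returned above go straight into AdamW, so biases and
# norm weights are exempt from weight decay. Example with a toy module (not the policy model):
#
#   toy = torch.nn.Linear(4, 4)                                    # params named "weight", "bias"
#   groups = get_optimizer_grouped_parameters(toy, weight_decay=0.1)
#   # groups[0]["params"] == [toy.weight]   (decayed)
#   # groups[1]["params"] == [toy.bias]     (no decay, matched by "bias")
#   opt = torch.optim.AdamW(groups, lr=2e-5)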
def _z3_params_to_fetch(param_list):
return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
if axis is not None:
return ((values * mask).sum(axis=axis) / mask.sum(axis=axis)).mean()
else:
return (values * mask).sum() / mask.sum()
def masked_sum(
values: torch.Tensor,
mask: torch.Tensor,
axis: Optional[int] = None,
constant_normalizer: float = 1.0,
) -> torch.Tensor:
"""Compute sum of tensor with a masked values. Use a constant to normalize."""
if axis is not None:
return (values * mask).sum(axis=axis) / constant_normalizer
else:
return (values * mask).sum() / constant_normalizer
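# Worked example (comments only) for the two reductions above, for a single response of
# length 4 whose last position is padding (mask == 0):
#
#   values = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#   mask   = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
#   masked_mean(values, mask)                           # (1+2+3)/3 = 2.0
#   masked_mean(values, mask, axis=1)                   # per-row mean, then mean over rows = 2.0
#   masked_sum(values, mask, constant_normalizer=4.0)   # (1+2+3)/4 = 1.5 (DR-GRPO-style fixed normalizer)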
class MetricsTracker:
"""A simple class to prellocate all metrics in an array
so we can do only one allreduce operation to get the metrics mean"""
def __init__(self, max_metrics: int = 32, device: torch.device = torch.device("cuda")):
self.metrics = torch.zeros(max_metrics, device=device)
self.names2idx = {}
self.current_idx = 0
self.max_metrics = max_metrics
def add(self, name: str, value: torch.Tensor):
if name not in self.names2idx:
if self.current_idx >= self.max_metrics:
raise ValueError(f"Exceeded maximum number of metrics ({self.max_metrics})")
self.names2idx[name] = self.current_idx
self.current_idx += 1
self.metrics[self.names2idx[name]] = value
return self
def get_metrics_list(self) -> dict[str, float]:
metrics_list = self.metrics.tolist()
return {name: metrics_list[idx] for name, idx in self.names2idx.items()}
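# Hedged usage sketch for MetricsTracker: every metric lands in one preallocated tensor, so a
# single collective over `tracker.metrics` (e.g. an all_reduce, shown commented out) can
# average all of them at once across ranks.
#
#   tracker = MetricsTracker(max_metrics=32, device=torch.device("cpu"))
#   tracker.add("loss/policy_avg", torch.tensor(0.25)).add("val/ratio", torch.tensor(1.01))
#   # torch.distributed.all_reduce(tracker.metrics, op=torch.distributed.ReduceOp.AVG)
#   tracker.get_metrics_list()   # {"loss/policy_avg": 0.25, "val/ratio": 1.01}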
def collate_fn(tensors_list: List[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor:
padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id)
if pin_memory:
padded_tensor = padded_tensor.pin_memory()
return padded_tensor
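# Small illustration of collate_fn's right-padding behaviour (pin_memory disabled in the
# sketch so it also runs on CPU):
#
#   batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
#   collate_fn(batch, pad_token_id=0, pin_memory=False)
#   # tensor([[1, 2, 3],
#   #         [4, 5, 0]])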
def to_device_inplace(tensors_list: List[torch.Tensor], device: torch.device):
for i in range(len(tensors_list)):
tensors_list[i] = tensors_list[i].to(device, non_blocking=True)
class Timer:
"""A context manager for timing code blocks"""
def __init__(self, description: str, noop: int = 0):
self.description = description
self.noop = noop
def __enter__(self):
if self.noop:
return
self.start_time = time.perf_counter()
return self
def __exit__(self, type, value, traceback):
if self.noop:
return
self.end_time = time.perf_counter()
self.duration = self.end_time - self.start_time
print(f"{self.description}: {self.duration:.2f} seconds")
class ShufflingIterator:
def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None):
self.data = data.copy()
self.batch_size = batch_size
self.index = 0
self.rng = np.random.default_rng(seed)
self.rng.shuffle(self.data)
# Ensure the effective dataset size is divisible by batch_size
self.effective_size = len(self.data) - (len(self.data) % batch_size)
def __iter__(self) -> Iterator[List[int]]:
return self
def __next__(self) -> List[int]:
if self.index >= self.effective_size:
self.index = 0
self.rng.shuffle(self.data)
end_index = self.index + self.batch_size
batch = self.data[self.index : end_index].tolist()
self.index = end_index
return batch
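# Usage sketch for ShufflingIterator: it yields shuffled index batches indefinitely,
# reshuffling once the batch-size-aligned portion of the data is exhausted.
#
#   it = ShufflingIterator(np.arange(10), batch_size=4, seed=0)
#   next(it)   # e.g. [8, 2, 6, 1] -- four indices, order depends on the seed
#   next(it)   # the next four indices; the trailing 10 % 4 = 2 indices are dropped each pass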
class RayProcess:
def __init__(self, world_size, rank, local_rank, master_addr, master_port):
logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
self.world_size = world_size
self.rank = rank
self.local_rank = local_rank
self.master_addr = master_addr if master_addr else self.get_current_node_ip()
self.master_port = master_port if master_port else self.get_free_port()
os.environ["MASTER_ADDR"] = self.master_addr
os.environ["MASTER_PORT"] = str(self.master_port)
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["RANK"] = str(self.rank)
# NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES
# environment variable for each actor, so always set device to 0
# os.environ["LOCAL_RANK"] = str(self._local_rank)
os.environ["LOCAL_RANK"] = "0"
random.seed(self.rank)
np.random.seed(self.rank)
torch.manual_seed(self.rank)
@staticmethod
def get_current_node_ip():
address = ray._private.services.get_node_ip_address()
# strip ipv6 address
return address.strip("[]")
@staticmethod
def get_free_port():
with socket.socket() as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
def get_master_addr_port(self):
return self.master_addr, self.master_port
def empty_cache(self) -> None:
torch.cuda.empty_cache()
@ray.remote(num_gpus=1)
class PolicyTrainerRayProcess(RayProcess):
def from_pretrained(
self,
args: Args,
model_config: ModelConfig,
beaker_config: BeakerRuntimeConfig,
wandb_url: str,
tokenizer: PreTrainedTokenizer,
):
self.args = args
self.tokenizer = tokenizer
self.model_config = model_config
self.beaker_config = beaker_config
self.wandb_url = wandb_url
torch.cuda.set_device(self.local_rank)
self.device = torch.device(self.local_rank)
deepspeed.init_distributed()
ds_config = get_train_ds_config(
offload=False,
adam_offload=False,
stage=args.deepspeed_stage,
bf16=True,
)
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
ds_config["gradient_accumulation_steps"] = 1
# @vwxyzjn: MAGIC: constructing `HfDeepSpeedConfig` here is actually needed, see
# https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration
# It instructs transformers to partition the model directly over multiple GPUs using
# deepspeed.zero.Init when the model's `from_pretrained` method is called.
if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
HfDeepSpeedConfig(ds_config)
else:
pass
self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path,
revision=model_config.model_revision,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
use_cache=False,
)
disable_dropout_in_model(self.policy)
self.policy.gradient_checkpointing_enable()
# AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam
# AdamOptimizer = FusedAdam
if args.set_weight_decay_on_bias_and_norm:
optim_params = get_optimizer_grouped_parameters(self.policy, args.weight_decay)
else:
optim_params = self.policy.parameters()
# self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate)
self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer)
num_scheduler_steps = args.num_training_steps * args.num_epochs * args.num_mini_batches
warm_up_steps = args.warm_up_steps
if args.warmup_ratio > 0.0:
warm_up_steps = int(num_scheduler_steps * args.warmup_ratio)
scheduler = get_scheduler(
args.lr_scheduler_type,
optimizer=self.optimizer,
num_warmup_steps=warm_up_steps,
num_training_steps=num_scheduler_steps,
)
self.model, self.optimizer, _, self.scheduler = deepspeed.initialize(
model=self.policy,
optimizer=self.optimizer,
config=ds_config,
lr_scheduler=scheduler,
dist_init_required=True,
)
self.model.train()
# reference model
ds_config = get_eval_ds_config(
offload=False,
# inference model only has stage 3 (sharding) or stage 0 (no sharding)
# stage 2 is optimizer sharding which doesn't apply to inference
stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
bf16=True,
)
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
ds_config["gradient_accumulation_steps"] = 1
if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
HfDeepSpeedConfig(ds_config)
else:
pass
self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path,
revision=model_config.model_revision,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
use_cache=False,
)
disable_dropout_in_model(self.ref_policy)
self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config)
self.ref_policy.eval()
self.local_metrics = MetricsTracker(max_metrics=32, device=self.device)
def forward(
self,
model: PreTrainedModel,
query_response: torch.LongTensor,
attention_mask: torch.LongTensor,
position_ids: torch.LongTensor,
pad_token_id: int,
temperature: float,
) -> torch.Tensor:
# Replace pad tokens with 0s so that we don't run into index out of bounds errors
padding_mask = query_response != pad_token_id
input_ids = torch.masked_fill(query_response, ~padding_mask, 0)
# NOTE: the [:-1] and [1:] are because the logits and generated tokens are off by 1 in index
output = model(
input_ids=input_ids[:, :-1],
# @vwxyzjn: without clamp, we get index out of bounds errors; TODO: investigate
attention_mask=attention_mask[:, :-1].clamp(0, 1),
position_ids=position_ids[:, :-1],
return_dict=True,
)
logits = output.logits
logits /= temperature + 1e-7
logprob = log_softmax_and_gather(logits, input_ids[:, 1:])
return logprob
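# Alignment sketch for `forward` above (comments only): for a packed row
# [q1, q2, r1, r2, pad], the model sees positions [:-1] and logprobs are gathered
# against tokens [1:]:
#
#   inputs  (input_ids[:, :-1]): q1 q2 r1 r2
#   targets (input_ids[:, 1:]) : q2 r1 r2 pad
#
# i.e. logprob[t] = log p(token_{t+1} | tokens_{<=t}); positions outside the response
# (queries, padding) are later overwritten with INVALID_LOGPROB via the response mask.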
def setup_model_update_group(self, vllm_engines):
self.vllm_engines = vllm_engines
if self.rank == 0:
master_address = ray._private.services.get_node_ip_address()
with socket.socket() as sock:
sock.bind(("", 0))
master_port = sock.getsockname()[1]
vllm_num_engines, vllm_tensor_parallel_size = (
self.args.vllm_num_engines,
self.args.vllm_tensor_parallel_size,
)
world_size = vllm_num_engines * vllm_tensor_parallel_size + 1
backend = self.args.vllm_sync_backend
refs = [
engine.init_process_group.remote(
master_address,
master_port,
i * vllm_tensor_parallel_size + 1,
world_size,
"openrlhf",
backend=backend,
)
for i, engine in enumerate(vllm_engines)
]
self.model_update_group = init_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=0,
group_name="openrlhf",
)
ray.get(refs)
torch.distributed.barrier()
def broadcast_to_vllm(self):
# avoid OOM
torch.cuda.empty_cache()
model = self.model.module
count, num_params = 0, len(list(model.named_parameters()))
refss = []
if self.args.gather_whole_model:
with deepspeed.zero.GatheredParameters(model.parameters(), enabled=self.args.deepspeed_stage == 3):
for name, param in model.named_parameters():
count += 1 # empty_cache at last param
# Fire all vllm engines for broadcast
if torch.distributed.get_rank() == 0:
shape = param.shape if self.args.deepspeed_stage != 3 else param.ds_shape
refs = [
engine.update_weight.remote(
name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
)
for engine in self.vllm_engines
]
refss.extend(refs)
if torch.distributed.get_rank() == 0:
torch.distributed.broadcast(param.data, 0, group=self.model_update_group)
else: # broadcast each parameter independently
for name, param in model.named_parameters():
count += 1
if torch.distributed.get_rank() == 0:
shape = param.shape if self.args.deepspeed_stage != 3 else param.ds_shape
refs = [
engine.update_weight.remote(
name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
)
for engine in self.vllm_engines
]
refss.extend(refs)
with deepspeed.zero.GatheredParameters([param], enabled=self.args.deepspeed_stage == 3):
if torch.distributed.get_rank() == 0:
torch.distributed.broadcast(param.data, 0, group=self.model_update_group)
if torch.distributed.get_rank() == 0:
ray.get(refss)
def train(
self,
collated_query_responses,
collated_attention_masks,
collated_position_ids,
collated_advantages,
collated_response_masks,
pad_token_id: int,
num_mini_batches: int,
):
args = self.args
to_device_inplace(collated_query_responses, self.device)
to_device_inplace(collated_attention_masks, self.device)
to_device_inplace(collated_position_ids, self.device)
to_device_inplace(collated_advantages, self.device)
to_device_inplace(collated_response_masks, self.device)
accumulation_steps = len(collated_query_responses) // (num_mini_batches)
# Calculate the logprob of the reference policy
collated_ref_logprobs = []
with Timer("Inference Calculation", noop=self.rank != 0):
with torch.no_grad():
for i in range(len(collated_query_responses)):
query_response = collated_query_responses[i]
attention_mask = collated_attention_masks[i]
position_id = collated_position_ids[i]
response_mask = collated_response_masks[i]
ref_logprob = self.forward(
self.ref_policy,
query_response,
attention_mask,
position_id,
pad_token_id,
args.temperature,
)
ref_logprob = torch.masked_fill(ref_logprob, ~response_mask[:, 1:].bool(), INVALID_LOGPROB)
collated_ref_logprobs.append(ref_logprob)
torch.cuda.empty_cache()
local_step = 0
# Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
with Timer("[Training Processes] Loss calculation", noop=self.rank != 0):
old_logprobs = [None for _ in range(len(collated_query_responses))]
kl1_stats = torch.zeros(len(collated_query_responses))
kl2_stats = torch.zeros(len(collated_query_responses))
kl3_stats = torch.zeros(len(collated_query_responses))
kl4_stats = torch.zeros(len(collated_query_responses))
pg_clipfrac_stats = torch.zeros(len(collated_query_responses))
pg_loss_stats = torch.zeros(len(collated_query_responses))
loss_stats = torch.zeros(len(collated_query_responses))
ratio_stats = torch.zeros(len(collated_query_responses))
for epoch_idx in range(args.num_epochs):
for i in range(len(collated_query_responses)):
mb_ref_logprob = collated_ref_logprobs[i]
mb_query_responses = collated_query_responses[i]
mb_advantages = collated_advantages[i]
mb_response_masks = collated_response_masks[i]
mb_response_masks_bool = mb_response_masks[:, 1:].bool()
mb_attention_mask = collated_attention_masks[i]
mb_position_id = collated_position_ids[i]
mb_new_logprobs = self.forward(
self.model,
mb_query_responses,
mb_attention_mask,
mb_position_id,
pad_token_id,
args.temperature,
)
mb_new_logprobs = torch.masked_fill(mb_new_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
# Cache the old logprobs
with torch.no_grad():
if epoch_idx == 0:
old_logprobs[i] = mb_new_logprobs
mb_old_logprobs = old_logprobs[i].detach()
# Calculate the policy's loss
logprobs_diff = mb_new_logprobs - mb_old_logprobs
ratio = torch.exp(logprobs_diff)
pg_losses = -mb_advantages[:, 1:] * ratio
pg_losses2 = -mb_advantages[:, 1:] * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)
# Here we recalculate kl: we want the KL loss to backpropagate through the model
# We also clamp the KL loss to avoid numerical instability
# https://chatgpt.com/share/679d0ed9-8f48-8011-926e-e274b15ae8ae
ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0)
kl1 = ref_logprobs_diff
kl2 = (ref_logprobs_diff) ** 2 / 2
kl3 = torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff # this is more numerically stable
kl4 = ratio * ref_logprobs_diff
if args.kl_estimator == "kl1":
kl = kl1
elif args.kl_estimator == "kl2":
kl = kl2
elif args.kl_estimator == "kl3":
kl = kl3
elif args.kl_estimator == "kl4":
kl = kl4
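# The estimators above are written in terms of d = logprob_policy - logprob_ref and
# correspond to the k1/k2/k3 approximations of KL(pi || pi_ref) from John Schulman's
# "Approximating KL Divergence" note (comment-only recap, no behaviour change):
#
#   kl1 = d                  # unbiased, high variance, can be negative
#   kl2 = d**2 / 2           # low variance, biased
#   kl3 = expm1(-d) + d      # = exp(-d) - 1 + d, non-negative, low variance
#   kl4 = ratio * d          # weighted by the pi_new / pi_old importance ratio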
# GRPO change: add the KL penalty term directly to the loss (instead of folding it into the reward)
if args.dr_grpo:
loss = masked_sum(pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis, constant_normalizer=args.response_length)
else:
loss = masked_mean(pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis)
loss = loss / accumulation_steps
self.model.backward(loss)
if (local_step + 1) % accumulation_steps == 0:
self.model.step()
local_step += 1
with torch.no_grad():
# NOTE: in the packed implementation, the KL statistics are averaged over response tokens
kl1_stats[i] = masked_mean(kl1, mb_response_masks_bool, args.masked_mean_axis).float()
kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, args.masked_mean_axis).float()
kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, args.masked_mean_axis).float()
kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, args.masked_mean_axis).float()
pg_clipfrac_stats[i] = masked_mean(
(pg_losses2 > pg_losses).float(), mb_response_masks_bool, args.masked_mean_axis
)
pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool, args.masked_mean_axis)
loss_stats[i] = loss
ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, args.masked_mean_axis)
with torch.no_grad():
self.local_metrics.add("objective/kl_avg", kl1_stats.mean())
self.local_metrics.add("objective/kl2_avg", kl2_stats.mean())
self.local_metrics.add("objective/kl3_avg", kl3_stats.mean())
self.local_metrics.add("objective/kl4_avg", kl4_stats.mean())
self.local_metrics.add("loss/policy_avg", pg_loss_stats.mean())
self.local_metrics.add("loss/policy_avg", loss_stats.mean())
self.local_metrics.add("policy/clipfrac_avg", pg_clipfrac_stats.mean())
self.local_metrics.add("val/ratio", ratio_stats.mean())
self.local_metrics.add("val/ratio_var", ratio_stats.var())
self.local_metrics.add("lr", self.scheduler.get_last_lr()[0])
return self.local_metrics.get_metrics_list()
def save_model(self, output_dir: str) -> None:
model_to_save = self.model
if self.rank == 0:
os.makedirs(output_dir, exist_ok=True)
# save model weights for ZeRO2/3
if hasattr(model_to_save, "module"):
model_to_save = model_to_save.module
# gather parameters
output_state_dict = {}
for k, v in model_to_save.named_parameters():
# only gather z3 params
params_to_fetch = _z3_params_to_fetch([v])
with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0):
vv = v.data.cpu()
if self.rank == 0:
output_state_dict[k] = vv
if self.rank == 0:
state_dict = model_to_save.state_dict()
# copy named_buffers with `persistent=True`
for k, v in model_to_save.named_buffers():
if k not in state_dict:
continue
vv = v.data.cpu()
output_state_dict[k] = vv
state_dict_keys = set(state_dict.keys())
output_state_dict_keys = set(output_state_dict.keys())
# corner case for tie_word_embeddings, such as Qwen2-0.5B
if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys:
state_dict_keys.remove("lm_head.weight")
assert state_dict_keys.issubset(
output_state_dict_keys
), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}"
# only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295
if isinstance(model_to_save, PeftModel):
model_to_save.save_pretrained(output_dir)
if self.args.deepspeed_stage == 3:
torch.save(
get_peft_model_state_dict(model_to_save, output_state_dict),
os.path.join(output_dir, "adapter_model.bin"),
)
else:
model_to_save.save_pretrained(output_dir, state_dict=output_state_dict)
# save tokenizer
self.tokenizer.save_pretrained(output_dir)
# we need this because we don't know which node rank 0 is on
def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url, training_step):
args = self.args
if self.rank == 0:
future = (
ray.remote(launch_ai2_evals_on_weka)
.options(num_cpus=1)
.remote(
step_dir,
leaderboard_name,
args.oe_eval_max_length,
wandb_url,
training_step,
args.oe_eval_tasks,
args.stop_strings,
args.gs_bucket_path,
args.eval_priority,
)
)
else:
future = None
return future
class ModelGroup:
def __init__(
self,
pg: PlacementGroup,
ray_process_cls: RayProcess,
num_gpus_per_node: List[int],
single_gpu_mode: bool,
):
self.pg = pg
self.ray_process_cls = ray_process_cls
self.num_gpus_per_node = num_gpus_per_node
self.num_gpus_per_actor = 0.48 if single_gpu_mode else 1
self.num_cpus_per_actor = 4
self.models = []
world_size = sum(self.num_gpus_per_node)
master_policy = ray_process_cls.options(
num_cpus=self.num_cpus_per_actor,
num_gpus=self.num_gpus_per_actor,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=self.pg, placement_group_bundle_index=0
),
).remote(world_size, 0, 0, None, None)
self.models.append(master_policy)
master_addr, master_port = ray.get(master_policy.get_master_addr_port.remote())
def get_bundle_index(rank, num_gpus_per_node):
"""given a rank and a list of num_gpus_per_node, return the index of the bundle that the rank belongs to"""
bundle_idx = 0
while rank >= num_gpus_per_node[bundle_idx]:
rank -= num_gpus_per_node[bundle_idx]
bundle_idx += 1
return bundle_idx
assert get_bundle_index(0, [7, 8, 4]) == 0
assert get_bundle_index(1, [7, 8, 4]) == 0
assert get_bundle_index(7, [7, 8, 4]) == 1
assert get_bundle_index(8, [7, 8, 4]) == 1
assert get_bundle_index(9, [7, 8, 4]) == 1
assert get_bundle_index(16, [7, 8, 4]) == 2
# Setup worker models
for rank in range(1, world_size):
print(f"{rank=}, {world_size=}, {rank=}, {master_addr=}, {master_port=}")
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=self.pg,
placement_group_bundle_index=get_bundle_index(rank, self.num_gpus_per_node),
)
worker_policy = ray_process_cls.options(
num_cpus=self.num_cpus_per_actor,
num_gpus=self.num_gpus_per_actor,
scheduling_strategy=scheduling_strategy,
).remote(world_size, rank, 0, master_addr, master_port)
self.models.append(worker_policy)
def vllm_generate_thread(
vllm_engines: List[ray.actor.ActorHandle],
generation_config: SamplingParams,
eval_generation_config: SamplingParams,
inference_results_Q: Queue,
param_prompt_Q: Queue,
num_training_steps: int,
eval_prompt_token_ids: Optional[List[int]],
evaluation_inference_results_Q: Queue,
eval_freq: int,
resume_training_step: int = 1,
):
def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams):
# Split queries between engines
queries_per_engine = (len(prompts) + len(vllm_engines) - 1) // len(vllm_engines)
split_queries = [prompts[i : i + queries_per_engine] for i in range(0, len(prompts), queries_per_engine)]
# Generate responses in parallel across engines
futures = [
vllm_engine.generate.remote(sampling_params=sampling_params, prompt_token_ids=queries, use_tqdm=False)
for vllm_engine, queries in zip(vllm_engines, split_queries)
]
# Gather all responses
all_outputs = ray.get(futures)
response_ids = []
finish_reasons = [] # either "stop" or "length"
for outputs in all_outputs:
response_ids.extend([list(out.token_ids) for output in outputs for out in output.outputs])
finish_reasons.extend([out.finish_reason for output in outputs for out in output.outputs])
return response_ids, finish_reasons
for training_step in range(resume_training_step, num_training_steps + 1):
items = param_prompt_Q.get()
if items is None:
break
_, g_queries_list = items
with Timer("🔥 Generation time"):
response_ids, finish_reasons = generate_with_engines(g_queries_list, generation_config)
inference_results_Q.put((response_ids, finish_reasons))
# Evaluate the model
if eval_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0:
response_ids, finish_reasons = generate_with_engines(eval_prompt_token_ids, eval_generation_config)
evaluation_inference_results_Q.put((response_ids, finish_reasons))
def data_preparation_thread(
reward_fn: Callable,
inference_results_Q: Queue,
packed_sequences_Q: Queue,
queries_prompt_Q: Queue,
args: Args,
tokenizer: PreTrainedTokenizer,
num_training_steps: int,
):
for training_step in range(1, num_training_steps + 1):
# Get next batch of prompts and responses
items = queries_prompt_Q.get()
queries, ground_truths, datasets = items
# ------------------------------------------------------------------------------------------------
# Pack sequences
if args.num_samples_per_prompt_rollout > 1:
queries = [item for item in queries for _ in range(args.num_samples_per_prompt_rollout)]
ground_truths = [item for item in ground_truths for _ in range(args.num_samples_per_prompt_rollout)]
datasets = [item for item in datasets for _ in range(args.num_samples_per_prompt_rollout)]
with Timer("🚀 [Data Preparation Thread] Getting response ids"):
responses, finish_reasons = inference_results_Q.get()
for i in range(len(finish_reasons)):
if finish_reasons[i] == "stop" and responses[i][-1] != tokenizer.eos_token_id:
responses[i].append(tokenizer.eos_token_id)
with Timer("📦 [Data Preparation Thread] Packing sequences"):
packed_sequences = pack_sequences(
queries=queries,
responses=responses,
pack_length=args.pack_length,
pad_token_id=tokenizer.pad_token_id,
)
num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)
with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
decoded_responses = tokenizer.batch_decode(responses, skip_special_tokens=True)
stop_rate = sum(int(finish_reason == "stop") for finish_reason in finish_reasons) / len(finish_reasons)
with Timer("💰 [Data Preparation Thread] Calculating rewards"):
scores, reward_metrics = reward_fn(responses, decoded_responses, ground_truths, datasets, finish_reasons)
with Timer("🎆 [Data Preparation Thread] Calculating advantages"):
# Calculate advantages
scores = np.array(scores)
print(f"[Data Preparation Thread] {len(scores)=}")
scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout)
global_mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
global_mean_grouped_rewards = np.repeat(
global_mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
)
if args.dr_grpo:
global_advantages = (scores - global_mean_grouped_rewards)
else:
global_std_grouped_rewards = scores_per_prompt.std(axis=-1)
global_std_grouped_rewards = np.repeat(
global_std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
)
global_advantages = (scores - global_mean_grouped_rewards) / (global_std_grouped_rewards + 1e-8)
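# Tiny worked example of the group-normalized advantage above: with
# num_samples_per_prompt_rollout = 4 and one prompt whose samples scored [10, 0, 0, 10],
# the group mean is 5 and the (population) std is 5, so the advantages are
# [+1, -1, -1, +1]; with dr_grpo=True the std division is skipped and the advantages
# are just the mean-centered scores [+5, -5, -5, +5].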
# Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value
# and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed)
lookup_advantages = np.zeros(len(global_advantages) + 1, dtype=np.float32)
lookup_advantages[1:] = global_advantages
packed_advantages = [
torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32)
for packed_mask in packed_sequences.response_masks
]
packed_sequences.advantages = packed_advantages
with Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
B = (
len(packed_sequences.query_responses) // args.world_size
) # essentially doing `drop_last=True`, which is fine.
collated_data = []
for i in range(args.world_size):
per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)]
per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)]
per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)]
per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)]
per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)]
# Shuffle the batch and collate the data
b_inds = np.random.permutation(len(per_device_packed_query_responses))
collated_query_responses = []
collated_attention_masks = []
collated_position_ids = []
collated_response_masks = []
collated_advantages = []
for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size):
micro_range = b_inds[j : j + args.per_device_train_batch_size]
collated_query_responses.append(
collate_fn(
[per_device_packed_query_responses[idx] for idx in micro_range], tokenizer.pad_token_id
)
)
collated_attention_masks.append(
collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0)
)
collated_position_ids.append(
collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0)
)
collated_response_masks.append(
collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0)
)
collated_advantages.append(
collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0)
)
collated_data.append(
{
"collated_query_responses": collated_query_responses,
"collated_attention_masks": collated_attention_masks,
"collated_position_ids": collated_position_ids,
"collated_advantages": collated_advantages,
"collated_response_masks": collated_response_masks,
}
)
# Create a result package with metrics and data
sequence_lengths = np.array([len(response) for response in responses])
metrics = {
"scores": np.array(scores).mean(),
"val/sequence_lengths": sequence_lengths.mean(),
"val/sequence_lengths_min": sequence_lengths.min(),
"val/sequence_lengths_max": sequence_lengths.max(),
"val/stop_rate": stop_rate,
**reward_metrics,
}
if args.save_traces:
traces = {
"scores": scores.tolist(),
"finish_reasons": finish_reasons,
"responses": responses,
"queries": queries,
"ground_truths": ground_truths,
"datasets": datasets,
"training_step": training_step,
**reward_metrics,
}
os.makedirs(args.output_dir, exist_ok=True)
with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f:
json.dump(traces, f)
f.write("\n")
# Put the packed sequences and metrics into the output queue
packed_sequences_Q.put(
{
"packed_sequences": packed_sequences, # for debugging purposes
"collated_data": collated_data,
"metrics": metrics,
"responses_count": len(responses),
"num_new_tokens": num_new_tokens,
"B": B,
}
)
def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn: Callable):
# ------------------------------------------------------------
# Setup tokenizer
tc.tokenizer_revision = model_config.model_revision if tc.tokenizer_revision is None else tc.tokenizer_revision
tc.tokenizer_name_or_path = (
model_config.model_name_or_path if tc.tokenizer_name_or_path is None else tc.tokenizer_name_or_path
)
if (
tc.tokenizer_revision != model_config.model_revision
and tc.tokenizer_name_or_path != model_config.model_name_or_path
):
# Warn user if tokenizer and model use different revisions; this is an unusual
# use case.
warning = f"""Requested tokenizer revision `{tc.tokenizer_revision=}` is different
from the model revision `{model_config.model_revision=}` or the tokenizer name `{tc.tokenizer_name_or_path=}`
is different from the model name `{model_config.model_name_or_path=}`."""
print(warning)
tokenizer = tc.tokenizer
# ------------------------------------------------------------
# Set up runtime variables
args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
args.output_dir = os.path.join(args.output_dir, args.run_name)
args.dataset_local_cache_dir = os.path.abspath(args.dataset_local_cache_dir)
if is_beaker_job():
args.dataset_local_cache_dir = "/weka/oe-adapt-default/allennlp/deletable_open_instruct_dataset_cache"
args.world_size = sum(args.num_learners_per_node)
args.num_training_steps = args.total_episodes // (
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
)
args.eval_freq = max(1, args.num_training_steps // args.num_evals)
args.try_launch_beaker_eval_jobs_on_weka = args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job()
if args.push_to_hub:
if args.hf_repo_id is None: # auto-generate one
args.hf_repo_id = "open_instruct_dev"
if args.hf_entity is None: # first try to use AI2 entity
args.hf_entity = maybe_use_ai2_hf_entity()
if args.hf_entity is None: # then try to use the user's entity
args.hf_entity = HfApi().whoami()["name"]
args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
if args.hf_repo_revision is None: # auto-generate one
args.hf_repo_revision = args.run_name
args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}"
if args.with_tracking:
if args.wandb_entity is None:
args.wandb_entity = maybe_use_ai2_wandb_entity()
# ------------------------------------------------------------
# Setup experiment tracking and seeds
all_configs = {}
beaker_config = None
if is_beaker_job():
beaker_config = maybe_get_beaker_config()
all_configs.update(vars(beaker_config))
all_configs.update(**asdict(args), **asdict(tc), **asdict(model_config))
if args.with_tracking:
import wandb
wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=all_configs,
name=args.run_name,
save_code=True,
tags=[args.exp_name] + get_wandb_tags(),
)
writer = SummaryWriter(f"runs/{args.run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# ------------------------------------------------------------
# Set up datasets
transform_fn_args = [
{},
{
"max_token_length": args.max_token_length,
"max_prompt_token_length": args.max_prompt_token_length,
},
]
train_dataset = get_cached_dataset_tulu(
dataset_mixer_list=args.dataset_mixer_list,
dataset_mixer_list_splits=args.dataset_mixer_list_splits,
tc=tc,
dataset_transform_fn=args.dataset_transform_fn,
transform_fn_args=transform_fn_args,
dataset_cache_mode=args.dataset_cache_mode,
dataset_config_hash=args.dataset_config_hash,
hf_entity=args.hf_entity,
dataset_local_cache_dir=args.dataset_local_cache_dir,
dataset_skip_cache=args.dataset_skip_cache,
)
train_dataset = train_dataset.shuffle(seed=args.seed)
eval_dataset = None
if len(args.dataset_mixer_eval_list) > 0:
eval_dataset = get_cached_dataset_tulu(
args.dataset_mixer_eval_list,
args.dataset_mixer_eval_list_splits,
tc,
args.dataset_transform_fn,
transform_fn_args,
hf_entity=args.hf_entity,
dataset_cache_mode=args.dataset_cache_mode,
dataset_config_hash=args.dataset_config_eval_hash,
dataset_local_cache_dir=args.dataset_local_cache_dir,
dataset_skip_cache=args.dataset_skip_cache,
)
eval_dataset = eval_dataset.shuffle(seed=args.seed)
if args.cache_dataset_only:
return
# ------------------------------------------------------------
# Runtime setups and quick logging
pprint([args, model_config])
visualize_token(train_dataset[0][INPUT_IDS_PROMPT_KEY], tokenizer)
# ------------------------------------------------------------
# Create the model and optimizer
pg = None
bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node]
pg = placement_group(bundles, strategy="STRICT_SPREAD")
ray.get(pg.ready())
inits = []
policy_group = ModelGroup(
pg,
PolicyTrainerRayProcess,
args.num_learners_per_node,
args.single_gpu_mode,
)
wandb_url = wandb.run.get_url() if args.with_tracking else None
inits.extend(
model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer)
for model in policy_group.models
)
max_len = args.max_prompt_token_length + args.response_length
vllm_engines = create_vllm_engines(
args.vllm_num_engines,
args.vllm_tensor_parallel_size,
args.vllm_enforce_eager,
model_config.model_name_or_path,
model_config.model_revision,
args.seed,
args.vllm_enable_prefix_caching,
max_len,
args.vllm_gpu_memory_utilization,
args.single_gpu_mode,
pg=pg if args.single_gpu_mode else None,
)
ray.get(inits)
print("======== ✅ all models and vLLM engines initialized =========")
ray.get([m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models])
print("======== ✅ model update group setup successfully =========")
# Setup training
generation_config = SamplingParams(
temperature=args.temperature,
top_p=1.0,
max_tokens=args.response_length,
include_stop_str_in_output=True,
n=args.num_samples_per_prompt_rollout,
stop=args.stop_strings,
)
eval_generation_config = SamplingParams(
temperature=0.0,
top_p=1.0,
max_tokens=args.response_length,
include_stop_str_in_output=True,
n=1, # since we are doing greedy sampling, don't need to generate more
stop=args.stop_strings,
)
train_dataset_idxs = np.arange(len(train_dataset))
iter_dataloader = ShufflingIterator(train_dataset_idxs, args.num_unique_prompts_rollout, seed=args.seed)
inference_results_Q = Queue(maxsize=1)
param_prompt_Q = Queue(maxsize=1)
evaluation_inference_results_Q = Queue(maxsize=1)
packed_sequences_Q = Queue(maxsize=1)
queries_prompt_Q = Queue(maxsize=1)
num_eval_samples = 32
eval_prompt_token_ids = None
eval_ground_truths = None
if eval_dataset is not None:
eval_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY]
eval_ground_truths = eval_dataset[:num_eval_samples][GROUND_TRUTHS_KEY]
resume_training_step = 1
thread = threading.Thread(
target=vllm_generate_thread,
args=(
vllm_engines,
generation_config,
eval_generation_config,
inference_results_Q,
param_prompt_Q,
args.num_training_steps,
eval_prompt_token_ids,
evaluation_inference_results_Q,
args.eval_freq,
resume_training_step,
),
)
thread.start()
print("======== ✅ vllm generate thread starts =========")
packing_thread = threading.Thread(
target=data_preparation_thread,
args=(
reward_fn,
inference_results_Q,
packed_sequences_Q,
queries_prompt_Q,
args,
tokenizer,
args.num_training_steps,
),
)
packing_thread.start()
print("======== ✅ data preparation thread starts =========")
# Send initial data to both threads
data_next = train_dataset[next(iter_dataloader)]
queries_next = data_next[INPUT_IDS_PROMPT_KEY]
ground_truths_next = data_next[GROUND_TRUTHS_KEY]
datasets_next = data_next[DATASET_SOURCE_KEY]
queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next))
param_prompt_Q.put((None, queries_next))
episode = 0
num_total_tokens = 0
start_time = time.time()
eval_futures = []
try:
for training_step in range(resume_training_step, args.num_training_steps + 1):
print("-" * 100)
episode += (
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
) # each sample is an episode
# ------------------------------------------------------------------------------------------------
# Optionally evaluate the model
try:
evaluation_responses, _ = evaluation_inference_results_Q.get(timeout=0.01)
print("[Main Thread] 📊 Evaluation responses received")
table = {}
table["prompt"] = tokenizer.batch_decode(eval_prompt_token_ids)
table["response"] = tokenizer.batch_decode(evaluation_responses)
table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]]
table["ground_truth"] = eval_ground_truths
df = pd.DataFrame(table)
if args.with_tracking:
wandb.log({"sample_completions": wandb.Table(dataframe=df)})
else:
print_rich_table(df.iloc[:1])
del table
except Empty:
print("[Main Thread] 🙈 Evaluation responses not received")
# ------------------------------------------------------------------------------------------------
# Sync weights and send the next batch of prompts to vLLM
if args.async_mode:
if training_step != 1:
data_next = train_dataset[next(iter_dataloader)]
queries_next = data_next[INPUT_IDS_PROMPT_KEY]
ground_truths_next = data_next[GROUND_TRUTHS_KEY]
datasets_next = data_next[DATASET_SOURCE_KEY]
with Timer("[Main Thread] 🔄 Loading weights using shared memory"):
ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models])
queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next))
param_prompt_Q.put((None, queries_next))
else:
if training_step != 1:
# NOTE: important: the indent here is different for sync mode
# we also set to use `queries = queries_next` immediately
data_next = train_dataset[next(iter_dataloader)]
queries_next = data_next[INPUT_IDS_PROMPT_KEY]
ground_truths_next = data_next[GROUND_TRUTHS_KEY]
datasets_next = data_next[DATASET_SOURCE_KEY]
with Timer("🔄 Loading weights using shared memory"):
ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models])
queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next))
param_prompt_Q.put((None, queries_next))
# ------------------------------------------------------------------------------------------------
# Get the packed sequences with advantages from the packing thread
with Timer("[Main Thread] 📦 Getting packed sequences from thread"):
packed_data = packed_sequences_Q.get()
packed_sequences = packed_data["packed_sequences"]
data_thread_metrics = packed_data["metrics"]
B = packed_data["B"]
collated_data = packed_data["collated_data"]
num_total_tokens += packed_data["num_new_tokens"]
# Log info about the packed sequences
print(
f"Number of training examples per device: {B=}, packed sequence fraction of original sequences: {len(packed_sequences.query_responses) / packed_data['responses_count']}"
)
if B == 0:
print("[Main Thread] 🤡 After packing, there is not enough data to train")
continue
# ------------------------------------------------------------------------------------------------
# Train the model
with Timer("[Main Thread] 🗡️ Training"):
metrics_list: List[dict[str, float]] = ray.get(
[
policy_group.models[i].train.remote(
**collated_data[i],
pad_token_id=tokenizer.pad_token_id,
num_mini_batches=args.num_mini_batches,
)
for i in range(args.world_size)
]
)
average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]}
metrics = {
"episode": episode,
"training_step": training_step,
"val/num_total_tokens": num_total_tokens,
"epoch": episode / args.num_samples_per_prompt_rollout / len(train_dataset),
"tokens_per_second": num_total_tokens / (time.time() - start_time),
**data_thread_metrics,
**average_metrics,
}
print_rich_single_line_metrics(metrics)
for key, value in metrics.items():
writer.add_scalar(key, value, episode)
if args.save_freq > 0 and training_step % args.save_freq == 0:
with Timer("[Main Thread] 🗡️ Saving model"):
checkpoint_dir = f"{args.output_dir}_checkpoints"
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
print(f"Saving model at step {training_step} to {step_dir}")
ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}"
eval_futures.extend(
[
policy_group.models[i].launch_ai2_evals_on_weka_wrapper.remote(
step_dir, leaderboard_name, wandb_url, training_step
)
for i in range(args.world_size)
]
)
print(f"Saving final model at step {training_step} to {args.output_dir}")
with Timer("[Main Thread] 🗡️ Saving model"):
ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = args.hf_repo_revision
eval_futures.extend(
[
policy_group.models[i].launch_ai2_evals_on_weka_wrapper.remote(
args.output_dir, leaderboard_name, wandb_url, training_step
)
for i in range(args.world_size)
]
)
except Exception as e:
print(f"Training error occurred: {str(e)}")
print(traceback.format_exc())
ray.shutdown()
os._exit(1)
raise  # unreachable after os._exit(1); kept only as a safeguard
# Clean up threads
thread.join()
print("======== ✅ vllm generate thread ends =========")
packing_thread.join()
print("======== ✅ data preparation thread ends =========")
ray.shutdown()
# Ai2 logic: we use /output to store the artifacts of the job, so we
# make a copy of the model to `/output` in the end.
if (
args.try_auto_save_to_beaker
and is_beaker_job()
and len(beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir.rstrip("/") != "/output"
):
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
print("finished training")
accelerator = Namespace()
accelerator.is_main_process = True # hack
if args.push_to_hub:
print("Pushing model to hub")
push_folder_to_hub(
accelerator,
args.output_dir,
args.hf_repo_id,
args.hf_repo_revision,
)
if __name__ == "__main__":
parser = ArgumentParserPlus((Args, TokenizerConfig, ModelConfig))
args, tokenizer_config, model_config = parser.parse_args_into_dataclasses()
assert isinstance(args, Args)
assert isinstance(tokenizer_config, TokenizerConfig)
assert isinstance(model_config, ModelConfig)
def reward_fn(
responses: List[torch.Tensor],
decoded_responses: List[str],
ground_truths: List[str],
datasets: List[str],
finish_reasons: List[str],
) -> tuple[List[float], dict[str, float]]:
scores = [0] * len(decoded_responses)
metrics = {}
if args.apply_r1_style_format_reward:
with Timer("[Data Preparation Thread] Calculating rewards -- 🧮 Calculating format reward"):
format_scores = soft_format_reward_func(decoded_responses, args.r1_style_format_reward)
if len(format_scores) != len(scores):
raise ValueError(f"{len(format_scores)=} != {len(scores)=}")
for i in range(len(format_scores)):
scores[i] = format_scores[i] + scores[i]
metrics["val/format_scores"] = np.array(format_scores).mean()
if args.apply_verifiable_reward:
with Timer("[Data Preparation Thread] Calculating rewards -- 🏆 Applying verifiable reward"):
verifiable_rewards, per_func_rewards = apply_verifiable_reward(
responses,
decoded_responses,
ground_truths,
datasets,
reward_mult=args.verification_reward,
)
if len(verifiable_rewards) != len(scores):
raise ValueError(f"{len(verifiable_rewards)=} != {len(scores)=}")
for i in range(len(verifiable_rewards)):
scores[i] = verifiable_rewards[i] + scores[i]
np_verifiable_rewards = np.array(verifiable_rewards)
metrics["objective/verifiable_reward"] = np_verifiable_rewards.mean()
metrics["objective/verifiable_correct_rate"] = (np_verifiable_rewards > 0.0).mean()
# reshuffle around per_func rewards
per_func_lists = defaultdict(list)
for reward_dict in per_func_rewards:
for key, value in reward_dict.items():
per_func_lists[key].append(value)
# log per function rewards
for key, value in per_func_lists.items():
np_value = np.array(value)
metrics[f"objective/{key}_reward"] = np_value.mean()
metrics[f"objective/{key}_correct_rate"] = (np_value > 0.0).mean()
# this gets applied at the very end since it replaces (rather than adds to) the existing reward.
if args.non_stop_penalty:
with Timer("[Data Preparation Thread] Calculating rewards -- 🦖 Applying non stop penalty"):
assert len(finish_reasons) == len(scores)
for i in range(len(finish_reasons)):
if finish_reasons[i] != "stop":
scores[i] = args.non_stop_penalty_value
# @nouha: handle arithmetic reward
if args.apply_arithmetic_reward:
with Timer("[Data Preparation Thread] Calculating rewards -- 🧮 Calculating arithmetic reward"):
arithmetic_rewards = []
for i in range(len(decoded_responses)):
# extract the string between <answer> and </answer>
decoded_response = decoded_responses[i]
answer_start = decoded_response.find("<answer>") + len("<answer>")
answer_end = decoded_response.find("</answer>")
# normalize the number (e.g., 1,000 -> 1000)
try:
answer = decoded_response[answer_start:answer_end]
answer = answer.replace(",", "").strip()
if float(answer) == float(ground_truths[i]):
arithmetic_rewards.append(args.arithmetic_reward)
scores[i] += args.arithmetic_reward
else:
arithmetic_rewards.append(0)
except: # noqa
arithmetic_rewards.append(0)
pass # it's ok if things went wrong
metrics["objective/arithmetic_score"] = np.array(arithmetic_rewards).mean()
metrics["objective/arithmetic_correct_rate"] = (np.array(arithmetic_rewards) > 0.0).mean()
return scores, metrics
main(args, tokenizer_config, model_config, reward_fn)