# Copyright 2024 AllenAI. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# --------------------------------------------------------------------- | |
# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF | |
# which has the following license: | |
# Copyright [yyyy] [name of copyright owner] | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# isort: off | |
from collections import defaultdict | |
import json | |
import os | |
import shutil | |
os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA | |
# isort: on | |
import logging | |
import os | |
import random | |
import socket | |
import threading | |
import time | |
import traceback | |
from argparse import Namespace | |
from dataclasses import asdict, dataclass, field | |
from queue import Empty, Queue | |
from typing import Callable, Iterator, List, Literal, Optional | |
import deepspeed | |
import numpy as np | |
import pandas as pd | |
import ray | |
import torch | |
import torch.utils | |
import torch.utils.data | |
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus | |
from huggingface_hub import HfApi | |
from peft import PeftModel, get_peft_model_state_dict | |
from ray.util.placement_group import PlacementGroup, placement_group | |
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy | |
from rich.pretty import pprint | |
from torch.utils.tensorboard import SummaryWriter | |
from transformers import ( | |
AutoModelForCausalLM, | |
PreTrainedModel, | |
PreTrainedTokenizer, | |
get_scheduler, | |
) | |
from transformers.integrations import HfDeepSpeedConfig | |
from vllm import SamplingParams | |
from open_instruct.dataset_transformation import ( | |
DATASET_SOURCE_KEY, | |
GROUND_TRUTHS_KEY, | |
INPUT_IDS_PROMPT_KEY, | |
TokenizerConfig, | |
get_cached_dataset_tulu, | |
visualize_token, | |
) | |
from open_instruct.ground_truth_utils import soft_format_reward_func | |
from open_instruct.model_utils import ( | |
ModelConfig, | |
apply_verifiable_reward, | |
disable_dropout_in_model, | |
log_softmax_and_gather, | |
print_rich_single_line_metrics, | |
print_rich_table, | |
push_folder_to_hub, | |
) | |
from open_instruct.rl_utils2 import pack_sequences | |
from open_instruct.utils import ( | |
ArgumentParserPlus, | |
BeakerRuntimeConfig, | |
get_wandb_tags, | |
is_beaker_job, | |
launch_ai2_evals_on_weka, | |
maybe_get_beaker_config, | |
maybe_use_ai2_hf_entity, | |
maybe_use_ai2_wandb_entity, | |
) | |
from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group | |
api = HfApi() | |
INVALID_LOGPROB = 1.0 | |
@dataclass | |
class Args: | |
# Dataset | |
dataset_mixer_list: List[str] = field(default_factory=lambda: ["ai2-adapt-dev/rlvr_gsm8k_zs", "1.0"]) | |
"""A list of datasets (local or HF) to sample from.""" | |
dataset_mixer_eval_list: List[str] = field(default_factory=lambda: ["ai2-adapt-dev/rlvr_gsm8k_zs", "1.0"]) | |
"""A list of datasets (local or HF) to sample from for evaluation.""" | |
dataset_mixer_list_splits: List[str] = field(default_factory=lambda: ["train"]) | |
"""The dataset splits to use for training""" | |
dataset_mixer_eval_list_splits: List[str] = field(default_factory=lambda: ["test"]) | |
"""The dataset splits to use for evaluation""" | |
dataset_transform_fn: list[str] = field(default_factory=lambda: ["rlvr_tokenize_v1", "rlvr_filter_v1"]) | |
"""The list of transform functions to apply to the dataset.""" | |
dataset_cache_mode: Literal["hf", "local"] = "local" | |
"""The mode to use for caching the dataset.""" | |
dataset_local_cache_dir: str = "local_dataset_cache" | |
"""The directory to save the local dataset cache to.""" | |
dataset_config_hash: Optional[str] = None | |
"""The hash of the dataset configuration.""" | |
dataset_config_eval_hash: Optional[str] = None | |
"""The hash of the dataset configuration for evaluation.""" | |
dataset_skip_cache: bool = False | |
"""Whether to skip the cache.""" | |
max_token_length: int = 512 | |
"""The maximum token length to use for the dataset""" | |
max_prompt_token_length: int = 256 | |
"""The maximum prompt token length to use for the dataset""" | |
# Experiment | |
exp_name: str = os.path.basename(__file__)[: -len(".py")] | |
"""The name of this experiment""" | |
seed: int = 1 | |
"""Seed of the experiment""" | |
run_name: Optional[str] = None | |
"""RUNTIME VALUE: A unique name of this run""" | |
# Optimizer | |
learning_rate: float = 2e-5 | |
"""The initial learning rate for AdamW optimizer.""" | |
lr_scheduler_type: Literal[ | |
"linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" | |
] = "linear" | |
"""Which scheduler to use""" | |
warm_up_steps: int = 0 | |
"""Number of warm up steps for the scheduler""" | |
warmup_ratio: float = 0.0 | |
"""Ratio of warmup steps to total steps (takes precedence over `warm_up_steps`)""" | |
weight_decay: float = 0.0 | |
"""Weight decay for AdamW if we apply some.""" | |
set_weight_decay_on_bias_and_norm: bool = True | |
"""Whether to set weight decay on bias and norm layers""" | |
fused_optimizer: bool = False | |
"""Whether to use fused optimizer""" | |
# Batch sizes | |
per_device_train_batch_size: int = 1 | |
"""The forward batch size per device (local_micro_batch_size)""" | |
total_episodes: int = 100000 | |
"""The total number of episodes in the dataset""" | |
world_size: Optional[int] = None | |
"""RUNTIME VALUE: The number of processes (GPUs) to use""" | |
num_training_steps: Optional[int] = None | |
"""RUNTIME VALUE: The number of training_steps to train""" | |
num_evals: int = 10 | |
"""The number of evaluations to run throughout training""" | |
eval_freq: Optional[int] = None | |
"""RUNTIME VALUE: The frequency of evaluation steps""" | |
save_freq: int = -1 | |
"""How many train steps to save the model""" | |
# Generation | |
response_length: int = 256 | |
"""the length of the response""" | |
temperature: float = 0.7 | |
"""the sampling temperature""" | |
num_unique_prompts_rollout: int = 16 | |
"""The number of unique prompts during rollout""" | |
num_samples_per_prompt_rollout: int = 4 | |
"""the number of samples to generate per prompt during rollout, useful for easy-star""" | |
stop_strings: Optional[List[str]] = None | |
"""List of strings that stop the generation when they are generated. | |
The returned output will not contain the stop strings.""" | |
# Algorithm | |
async_mode: bool = True | |
"""Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)""" | |
num_epochs: int = 1 | |
"""the number of epochs to train""" | |
num_mini_batches: int = 1 | |
"""Number of minibatches to split a batch into""" | |
beta: float = 0.05 | |
"""the beta value of the RLHF objective (KL coefficient)""" | |
cliprange: float = 0.2 | |
"""the clip range""" | |
kl_estimator: Literal["kl1", "kl2", "kl3", "kl4"] = "kl3" | |
"""the KL estimator to use""" | |
pack_length: int = 512 | |
"""the length of the pack (you should prob set to the max length of the model)""" | |
dr_grpo: bool = False | |
"""whether to use the DR-GRPO objective (https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf)""" | |
masked_mean_axis: Optional[int] = None | |
"""the axis to compute the mean of the masked values""" | |
# Reward | |
# -- r1 style format reward | |
apply_r1_style_format_reward: bool = False | |
"""whether to add the R1 style format reward""" | |
r1_style_format_reward: float = 1.0 | |
"""the reward value for R1 style format reward""" | |
# -- verifiable reward | |
apply_verifiable_reward: bool = True | |
"""whether to apply verifiable reward""" | |
verification_reward: float = 10.0 | |
"""the reward value for verifiable responses""" | |
# -- non stop penalty | |
non_stop_penalty: bool = False | |
"""whether to penalize responses which did not finish generation""" | |
non_stop_penalty_value: float = 0.0 | |
"""the reward value for responses which did not finish generation""" | |
# -- arithmetic reward | |
apply_arithmetic_reward: bool = False | |
"""whether to apply arithmetic reward""" | |
arithmetic_reward: float = 10.0 | |
"""the reward value for arithmetic responses""" | |
# Ray | |
single_gpu_mode: bool = False | |
"""whether to collocate vLLM and actor on the same node (mostly for debugging purposes)""" | |
num_learners_per_node: List[int] = field(default_factory=lambda: [1]) | |
"""number of GPU deepspeed learners per node (e.g., --num_learners_per_node 2 4 means 2 learner processes | |
on the first node and 4 learner processes on the second node; each process will have 1 GPU)""" | |
vllm_num_engines: int = 1 | |
"""number of vLLM Engines, set to 0 to disable vLLM""" | |
vllm_tensor_parallel_size: int = 1 | |
"""tensor parallel size of vLLM Engine for multi-GPU inference""" | |
vllm_enforce_eager: bool = False | |
"""whether to enforce eager mode for vLLM -- slow inference but needed for multi-node""" | |
vllm_sync_backend: str = "nccl" | |
"""DeepSpeed -> vLLM weight sync backend""" | |
vllm_gpu_memory_utilization: float = 0.9 | |
"""vLLM GPU memory utilization""" | |
vllm_enable_prefix_caching: bool = False | |
"""whether to enable prefix caching""" | |
deepspeed_stage: int = 0 | |
"""the deepspeed stage""" | |
gather_whole_model: bool = True | |
"""whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)""" | |
# Experiment tracking | |
with_tracking: bool = False | |
"""If toggled, this experiment will be tracked with Weights and Biases""" | |
wandb_project_name: str = "open_instruct_internal" | |
"""The wandb's project name""" | |
wandb_entity: Optional[str] = None | |
"""The entity (team) of wandb's project""" | |
push_to_hub: bool = True | |
"""Whether to upload the saved model to huggingface""" | |
hf_entity: Optional[str] = None | |
"""The user or org name of the model repository from the Hugging Face Hub""" | |
hf_repo_id: Optional[str] = None | |
"""The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" | |
hf_repo_revision: Optional[str] = None | |
"""The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" | |
hf_repo_url: Optional[str] = None | |
"""The url of the saved model in the Hugging Face Hub (will be autoset)""" | |
output_dir: str = "output" | |
"""Where to save the model""" | |
save_traces: bool = False | |
"""Whether to save learning data traces""" | |
cache_dataset_only: bool = False | |
"""Immediately exit after caching the dataset""" | |
# Ai2 specific settings | |
try_launch_beaker_eval_jobs_on_weka: bool = False | |
"""Whether to launch beaker evaluation jobs after training on weka""" | |
try_auto_save_to_beaker: bool = True | |
"""Whether to try to save the model to Beaker dataset `/output` after training""" | |
gs_bucket_path: Optional[str] = None | |
"""The path to the gs bucket to save the model to""" | |
oe_eval_tasks: Optional[List[str]] = None | |
"""The beaker evaluation tasks to launch""" | |
oe_eval_max_length: int = 4096 | |
"""the max generation length for evaluation for oe-eval""" | |
eval_priority: Literal["low", "normal", "high", "urgent"] = "normal" | |
"""the priority of auto-launched evaluation jobs""" | |
def __post_init__(self): | |
if self.single_gpu_mode: | |
self.vllm_gpu_memory_utilization = 0.3 | |
assert self.num_samples_per_prompt_rollout > 1, "Number of samples per prompt must be greater than 1 for GRPO!" | |
assert ( | |
self.apply_verifiable_reward or self.apply_r1_style_format_reward or self.non_stop_penalty | |
), "At least one reward must be applied!" | |
assert ( | |
self.pack_length >= self.max_prompt_token_length + self.response_length | |
), "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" | |
def get_train_ds_config( | |
offload, | |
adam_offload=False, | |
stage=0, | |
bf16=True, | |
max_norm=1.0, | |
zpg=8, | |
grad_accum_dtype=None, | |
disable_trace_cache=True, | |
): | |
device = "cpu" if offload else "none" | |
zero_opt_dict = { | |
"stage": stage, | |
"offload_param": {"device": device}, | |
"offload_optimizer": { | |
"device": "cpu" if adam_offload else "none", | |
"pin_memory": True, | |
}, | |
"sub_group_size": "auto", | |
"stage3_max_live_parameters": "auto", | |
"stage3_max_reuse_distance": "auto", | |
"stage3_param_persistence_threshold": "auto", | |
"stage3_prefetch_bucket_size": "auto", | |
"reduce_bucket_size": "auto", | |
# # ZeRO++ | |
# "zero_hpz_partition_size": zpg, | |
# "zero_quantized_weights": False, | |
# "zero_quantized_gradients": False, | |
} | |
if disable_trace_cache: | |
zero_opt_dict["stage3_prefetch_bucket_size"] = 0 | |
zero_opt_dict["stage3_max_live_parameters"] = 0 | |
zero_opt_dict["stage3_max_reuse_distance"] = 0 | |
return { | |
"steps_per_print": 100, | |
"zero_optimization": zero_opt_dict, | |
"bf16": { | |
"enabled": bf16, | |
}, | |
"gradient_clipping": max_norm, | |
"prescale_gradients": False, | |
"wall_clock_breakdown": False, | |
"data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"}, | |
} | |
def get_eval_ds_config( | |
offload, | |
stage=0, | |
bf16=True, | |
): | |
zero_opt_dict = { | |
"stage": stage, | |
"stage3_param_persistence_threshold": "auto", | |
"offload_param": { | |
"device": "cpu" if offload else "none", | |
"pin_memory": True, | |
}, | |
} | |
return { | |
"steps_per_print": 100, | |
"zero_optimization": zero_opt_dict, | |
"bf16": { | |
"enabled": bf16, | |
}, | |
"prescale_gradients": False, | |
"wall_clock_breakdown": False, | |
} | |
def get_optimizer_grouped_parameters( | |
model: torch.nn.Module, | |
weight_decay: float, | |
no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], | |
): | |
optimizer_grouped_parameters = [ | |
{ | |
"params": [ | |
p | |
for n, p in model.named_parameters() | |
if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) | |
], | |
"weight_decay": weight_decay, | |
}, | |
{ | |
"params": [ | |
p | |
for n, p in model.named_parameters() | |
if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) | |
], | |
"weight_decay": 0.0, | |
}, | |
] | |
return optimizer_grouped_parameters | |
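# Note: the two groups above follow the usual Transformer convention -- weight decay is applied to
# every trainable parameter except biases and the *norm weights matched by `no_decay_name_list`
# (substring match), which get weight_decay=0.0. The returned list is passed directly to
# torch.optim.AdamW in `from_pretrained` below.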
def _z3_params_to_fetch(param_list): | |
return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] | |
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor: | |
"""Compute mean of tensor with a masked values.""" | |
if axis is not None: | |
return ((values * mask).sum(axis=axis) / mask.sum(axis=axis)).mean() | |
else: | |
return (values * mask).sum() / mask.sum() | |
def masked_sum( | |
values: torch.Tensor, | |
mask: torch.Tensor, | |
axis: Optional[int] = None,
constant_normalizer: float = 1.0, | |
) -> torch.Tensor: | |
"""Compute sum of tensor with a masked values. Use a constant to normalize.""" | |
if axis is not None: | |
return (values * mask).sum(axis=axis) / constant_normalizer | |
else: | |
return (values * mask).sum() / constant_normalizer | |
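# Illustrative sketch (not executed as part of the training path):
#     values = torch.tensor([1.0, 2.0, 3.0, 4.0])
#     mask = torch.tensor([1.0, 1.0, 0.0, 1.0])
#     masked_mean(values, mask)                           # (1 + 2 + 4) / 3 ≈ 2.33
#     masked_sum(values, mask, constant_normalizer=4.0)   # (1 + 2 + 4) / 4 = 1.75
# `train` below uses masked_mean for the standard GRPO objective and masked_sum with
# constant_normalizer=args.response_length for the Dr. GRPO objective, so every response is
# normalized by the same constant rather than by its own token count.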
class MetricsTracker: | |
"""A simple class to prellocate all metrics in an array | |
so we can do only one allreduce operation to get the metrics mean""" | |
def __init__(self, max_metrics: int = 32, device: torch.device = torch.device("cuda")): | |
self.metrics = torch.zeros(max_metrics, device=device) | |
self.names2idx = {} | |
self.current_idx = 0 | |
self.max_metrics = max_metrics | |
def add(self, name: str, value: torch.Tensor):
if name not in self.names2idx: | |
if self.current_idx >= self.max_metrics: | |
raise ValueError(f"Exceeded maximum number of metrics ({self.max_metrics})") | |
self.names2idx[name] = self.current_idx | |
self.current_idx += 1 | |
self.metrics[self.names2idx[name]] = value | |
return self | |
def get_metrics_list(self) -> dict[str, float]: | |
metrics_list = self.metrics.tolist() | |
return {name: metrics_list[idx] for name, idx in self.names2idx.items()} | |
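# Usage sketch (illustrative): each metric name gets a fixed slot in the preallocated tensor,
# so a single collective over `self.metrics` would aggregate all metrics at once, e.g.
#     tracker = MetricsTracker(max_metrics=4, device=torch.device("cuda"))
#     tracker.add("loss/policy_avg", torch.tensor(0.5))
#     tracker.get_metrics_list()   # -> {"loss/policy_avg": 0.5}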
def collate_fn(tensors_list: List[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: | |
padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) | |
if pin_memory: | |
padded_tensor = padded_tensor.pin_memory() | |
return padded_tensor | |
def to_device_inplace(tensors_list: List[torch.Tensor], device: torch.device): | |
for i in range(len(tensors_list)): | |
tensors_list[i] = tensors_list[i].to(device, non_blocking=True) | |
class Timer: | |
"""A context manager for timing code blocks""" | |
def __init__(self, description: str, noop: int = 0): | |
self.description = description | |
self.noop = noop | |
def __enter__(self): | |
if self.noop: | |
return | |
self.start_time = time.perf_counter() | |
return self | |
def __exit__(self, type, value, traceback): | |
if self.noop: | |
return | |
self.end_time = time.perf_counter() | |
self.duration = self.end_time - self.start_time | |
print(f"{self.description}: {self.duration:.2f} seconds") | |
class ShufflingIterator: | |
def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None): | |
self.data = data.copy() | |
self.batch_size = batch_size | |
self.index = 0 | |
self.rng = np.random.default_rng(seed) | |
self.rng.shuffle(self.data) | |
# Ensure the effective dataset size is divisible by batch_size | |
self.effective_size = len(self.data) - (len(self.data) % batch_size) | |
def __iter__(self) -> Iterator[List[int]]: | |
return self | |
def __next__(self) -> List[int]: | |
if self.index >= self.effective_size: | |
self.index = 0 | |
self.rng.shuffle(self.data) | |
end_index = self.index + self.batch_size | |
batch = self.data[self.index : end_index].tolist() | |
self.index = end_index | |
return batch | |
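# Usage sketch (illustrative): yields an endless stream of index batches, reshuffling after each
# full pass and dropping the tail that does not fill a batch, e.g.
#     it = ShufflingIterator(np.arange(10), batch_size=4, seed=0)
#     next(it)   # a list of 4 shuffled indices; only 8 of the 10 indices are used per pass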
class RayProcess: | |
def __init__(self, world_size, rank, local_rank, master_addr, master_port): | |
logging.basicConfig( | |
format="%(asctime)s %(levelname)-8s %(message)s", | |
level=logging.INFO, | |
datefmt="%Y-%m-%d %H:%M:%S", | |
) | |
self.world_size = world_size | |
self.rank = rank | |
self.local_rank = local_rank | |
self.master_addr = master_addr if master_addr else self.get_current_node_ip() | |
self.master_port = master_port if master_port else self.get_free_port() | |
os.environ["MASTER_ADDR"] = self.master_addr | |
os.environ["MASTER_PORT"] = str(self.master_port) | |
os.environ["WORLD_SIZE"] = str(self.world_size) | |
os.environ["RANK"] = str(self.rank) | |
# NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES | |
# environment variable for each actor, so always set device to 0 | |
# os.environ["LOCAL_RANK"] = str(self._local_rank) | |
os.environ["LOCAL_RANK"] = "0" | |
random.seed(self.rank) | |
np.random.seed(self.rank) | |
torch.manual_seed(self.rank) | |
@staticmethod | |
def get_current_node_ip(): | |
address = ray._private.services.get_node_ip_address() | |
# strip ipv6 address | |
return address.strip("[]") | |
@staticmethod | |
def get_free_port(): | |
with socket.socket() as sock: | |
sock.bind(("", 0)) | |
return sock.getsockname()[1] | |
def get_master_addr_port(self): | |
return self.master_addr, self.master_port | |
def empty_cache(self) -> None: | |
torch.cuda.empty_cache() | |
@ray.remote(num_gpus=1) | |
class PolicyTrainerRayProcess(RayProcess): | |
def from_pretrained( | |
self, | |
args: Args, | |
model_config: ModelConfig, | |
beaker_config: BeakerRuntimeConfig, | |
wandb_url: str, | |
tokenizer: PreTrainedTokenizer, | |
): | |
self.args = args | |
self.tokenizer = tokenizer | |
self.model_config = model_config | |
self.beaker_config = beaker_config | |
self.wandb_url = wandb_url | |
torch.cuda.set_device(self.local_rank) | |
self.device = torch.device(self.local_rank) | |
deepspeed.init_distributed() | |
ds_config = get_train_ds_config( | |
offload=False, | |
adam_offload=False, | |
stage=args.deepspeed_stage, | |
bf16=True, | |
) | |
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size | |
ds_config["gradient_accumulation_steps"] = 1 | |
# @vwxyzjn: MAGIC: this `dschf` object actually needs to be initialized (and kept alive), see
# https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration
# the next line instructs transformers to partition the model directly over multiple GPUs using
# deepspeed.zero.Init when the model's `from_pretrained` method is called.
if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
dschf = HfDeepSpeedConfig(ds_config)  # transformers keeps only a weak reference, so hold on to it
else:
dschf = None
self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( | |
model_config.model_name_or_path, | |
revision=model_config.model_revision, | |
torch_dtype=torch.bfloat16, | |
attn_implementation="flash_attention_2", | |
use_cache=False, | |
) | |
disable_dropout_in_model(self.policy) | |
self.policy.gradient_checkpointing_enable() | |
# AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam | |
# AdamOptimizer = FusedAdam | |
if args.set_weight_decay_on_bias_and_norm: | |
optim_params = get_optimizer_grouped_parameters(self.policy, args.weight_decay) | |
else: | |
optim_params = self.policy.parameters() | |
# self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) | |
self.optimizer = torch.optim.AdamW(optim_params, lr=args.learning_rate, fused=args.fused_optimizer) | |
num_scheduler_steps = args.num_training_steps * args.num_epochs * args.num_mini_batches | |
warm_up_steps = args.warm_up_steps | |
if args.warmup_ratio > 0.0: | |
warm_up_steps = int(num_scheduler_steps * args.warmup_ratio) | |
scheduler = get_scheduler( | |
args.lr_scheduler_type, | |
optimizer=self.optimizer, | |
num_warmup_steps=warm_up_steps, | |
num_training_steps=num_scheduler_steps, | |
) | |
self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( | |
model=self.policy, | |
optimizer=self.optimizer, | |
config=ds_config, | |
lr_scheduler=scheduler, | |
dist_init_required=True, | |
) | |
self.model.train() | |
# reference model | |
ds_config = get_eval_ds_config( | |
offload=False, | |
# inference model only has stage 3 (sharding) or stage 0 (no sharding) | |
# stage 2 is optimizer sharding which doesn't apply to inference | |
stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0, | |
bf16=True, | |
) | |
ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size | |
ds_config["gradient_accumulation_steps"] = 1 | |
if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
dschf = HfDeepSpeedConfig(ds_config)  # keep a reference so the ZeRO-3 integration stays active
else:
dschf = None
self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( | |
model_config.model_name_or_path, | |
revision=model_config.model_revision, | |
torch_dtype=torch.bfloat16, | |
attn_implementation="flash_attention_2", | |
use_cache=False, | |
) | |
disable_dropout_in_model(self.ref_policy) | |
self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config) | |
self.ref_policy.eval() | |
self.local_metrics = MetricsTracker(max_metrics=32, device=self.device) | |
def forward( | |
self, | |
model: PreTrainedModel, | |
query_response: torch.LongTensor, | |
attention_mask: torch.LongTensor, | |
position_ids: torch.LongTensor, | |
pad_token_id: int, | |
temperature: float, | |
) -> torch.Tensor: | |
# Replace pad tokens with 0s so that we don't run into index out of bounds errors | |
padding_mask = query_response != pad_token_id | |
input_ids = torch.masked_fill(query_response, ~padding_mask, 0) | |
# NOTE: the [:-1] and [1:] are because the logits and generated tokens are off by 1 in index | |
output = model( | |
input_ids=input_ids[:, :-1], | |
# @vwxyzjn: without clamp, we get index out of bounds errors; TODO: investigate | |
attention_mask=attention_mask[:, :-1].clamp(0, 1), | |
position_ids=position_ids[:, :-1], | |
return_dict=True, | |
) | |
logits = output.logits | |
logits /= temperature + 1e-7 | |
logprob = log_softmax_and_gather(logits, input_ids[:, 1:]) | |
return logprob | |
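# Note on the off-by-one above: feeding input_ids[:, :-1] produces logits where position t
# predicts token t + 1, so gathering against input_ids[:, 1:] yields the per-token
# log-probabilities of each generated token. Dividing the logits by `temperature` keeps these
# log-probs consistent with the temperature vLLM used at generation time.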
def setup_model_update_group(self, vllm_engines): | |
self.vllm_engines = vllm_engines | |
if self.rank == 0: | |
master_address = ray._private.services.get_node_ip_address() | |
with socket.socket() as sock: | |
sock.bind(("", 0)) | |
master_port = sock.getsockname()[1] | |
vllm_num_engines, vllm_tensor_parallel_size = ( | |
self.args.vllm_num_engines, | |
self.args.vllm_tensor_parallel_size, | |
) | |
world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 | |
backend = self.args.vllm_sync_backend | |
refs = [ | |
engine.init_process_group.remote( | |
master_address, | |
master_port, | |
i * vllm_tensor_parallel_size + 1, | |
world_size, | |
"openrlhf", | |
backend=backend, | |
) | |
for i, engine in enumerate(vllm_engines) | |
] | |
self.model_update_group = init_process_group( | |
backend=backend, | |
init_method=f"tcp://{master_address}:{master_port}", | |
world_size=world_size, | |
rank=0, | |
group_name="openrlhf", | |
) | |
ray.get(refs) | |
torch.distributed.barrier() | |
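# Note on the rank layout of the weight-sync process group: the DeepSpeed rank-0 trainer joins
# as rank 0 and each vLLM worker occupies ranks 1..N, where
# N = vllm_num_engines * vllm_tensor_parallel_size, hence world_size = N + 1 above.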
def broadcast_to_vllm(self): | |
# avoid OOM | |
torch.cuda.empty_cache() | |
model = self.model.module | |
count, num_params = 0, len(list(model.named_parameters())) | |
refss = [] | |
if self.args.gather_whole_model: | |
with deepspeed.zero.GatheredParameters(model.parameters(), enabled=self.args.deepspeed_stage == 3): | |
for name, param in model.named_parameters(): | |
count += 1 # empty_cache at last param | |
# Fire all vllm engines for broadcast | |
if torch.distributed.get_rank() == 0: | |
shape = param.shape if self.args.deepspeed_stage != 3 else param.ds_shape | |
refs = [ | |
engine.update_weight.remote( | |
name, dtype=param.dtype, shape=shape, empty_cache=count == num_params | |
) | |
for engine in self.vllm_engines | |
] | |
refss.extend(refs) | |
if torch.distributed.get_rank() == 0: | |
torch.distributed.broadcast(param.data, 0, group=self.model_update_group) | |
else: # broadcast each parameter independently | |
for name, param in model.named_parameters(): | |
count += 1 | |
if torch.distributed.get_rank() == 0: | |
shape = param.shape if self.args.deepspeed_stage != 3 else param.ds_shape | |
refs = [ | |
engine.update_weight.remote( | |
name, dtype=param.dtype, shape=shape, empty_cache=count == num_params | |
) | |
for engine in self.vllm_engines | |
] | |
refss.extend(refs) | |
with deepspeed.zero.GatheredParameters([param], enabled=self.args.deepspeed_stage == 3): | |
if torch.distributed.get_rank() == 0: | |
torch.distributed.broadcast(param.data, 0, group=self.model_update_group) | |
if torch.distributed.get_rank() == 0: | |
ray.get(refss) | |
def train( | |
self, | |
collated_query_responses, | |
collated_attention_masks, | |
collated_position_ids, | |
collated_advantages, | |
collated_response_masks, | |
pad_token_id: int, | |
num_mini_batches: int, | |
): | |
args = self.args | |
to_device_inplace(collated_query_responses, self.device) | |
to_device_inplace(collated_attention_masks, self.device) | |
to_device_inplace(collated_position_ids, self.device) | |
to_device_inplace(collated_advantages, self.device) | |
to_device_inplace(collated_response_masks, self.device) | |
accumulation_steps = len(collated_query_responses) // (num_mini_batches) | |
# Calculate the logprob of the reference policy | |
collated_ref_logprobs = [] | |
with Timer("Inference Calculation", noop=self.rank != 0): | |
with torch.no_grad(): | |
for i in range(len(collated_query_responses)): | |
query_response = collated_query_responses[i] | |
attention_mask = collated_attention_masks[i] | |
position_id = collated_position_ids[i] | |
response_mask = collated_response_masks[i] | |
ref_logprob = self.forward( | |
self.ref_policy, | |
query_response, | |
attention_mask, | |
position_id, | |
pad_token_id, | |
args.temperature, | |
) | |
ref_logprob = torch.masked_fill(ref_logprob, ~response_mask[:, 1:].bool(), INVALID_LOGPROB) | |
collated_ref_logprobs.append(ref_logprob) | |
torch.cuda.empty_cache() | |
local_step = 0 | |
# Do multiple epochs of training on the on-policy data (PPO-style); the microbatches were already shuffled once in the data preparation thread
with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): | |
old_logprobs = [None for _ in range(len(collated_query_responses))] | |
kl1_stats = torch.zeros(len(collated_query_responses)) | |
kl2_stats = torch.zeros(len(collated_query_responses)) | |
kl3_stats = torch.zeros(len(collated_query_responses)) | |
kl4_stats = torch.zeros(len(collated_query_responses)) | |
pg_clipfrac_stats = torch.zeros(len(collated_query_responses)) | |
pg_loss_stats = torch.zeros(len(collated_query_responses)) | |
loss_stats = torch.zeros(len(collated_query_responses)) | |
ratio_stats = torch.zeros(len(collated_query_responses)) | |
for epoch_idx in range(args.num_epochs): | |
for i in range(len(collated_query_responses)): | |
mb_ref_logprob = collated_ref_logprobs[i] | |
mb_query_responses = collated_query_responses[i] | |
mb_advantages = collated_advantages[i] | |
mb_response_masks = collated_response_masks[i] | |
mb_response_masks_bool = mb_response_masks[:, 1:].bool() | |
mb_attention_mask = collated_attention_masks[i] | |
mb_position_id = collated_position_ids[i] | |
mb_new_logprobs = self.forward( | |
self.model, | |
mb_query_responses, | |
mb_attention_mask, | |
mb_position_id, | |
pad_token_id, | |
args.temperature, | |
) | |
mb_new_logprobs = torch.masked_fill(mb_new_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB) | |
# Cache the old logprobs | |
with torch.no_grad(): | |
if epoch_idx == 0: | |
old_logprobs[i] = mb_new_logprobs | |
mb_old_logprobs = old_logprobs[i].detach() | |
# Calculate the policy's loss | |
logprobs_diff = mb_new_logprobs - mb_old_logprobs | |
ratio = torch.exp(logprobs_diff) | |
pg_losses = -mb_advantages[:, 1:] * ratio | |
pg_losses2 = -mb_advantages[:, 1:] * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) | |
pg_loss_max = torch.max(pg_losses, pg_losses2) | |
# Here we recalculate kl: we want the KL loss to backpropagate through the model | |
# We also clamp the KL loss to avoid numerical instability | |
# https://chatgpt.com/share/679d0ed9-8f48-8011-926e-e274b15ae8ae | |
ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0) | |
kl1 = ref_logprobs_diff | |
kl2 = (ref_logprobs_diff) ** 2 / 2 | |
kl3 = torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff # this is more numerically stable | |
kl4 = ratio * ref_logprobs_diff | |
if args.kl_estimator == "kl1": | |
kl = kl1 | |
elif args.kl_estimator == "kl2": | |
kl = kl2 | |
elif args.kl_estimator == "kl3": | |
kl = kl3 | |
elif args.kl_estimator == "kl4": | |
kl = kl4 | |
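# Note: with d = log pi_theta(y|x) - log pi_ref(y|x), the estimators above are
#     kl1 = d                        (plain Monte Carlo estimate of KL(pi_theta || pi_ref))
#     kl2 = d^2 / 2                  (low variance but biased)
#     kl3 = exp(-d) - 1 + d          (the "k3" estimator: unbiased and non-negative)
#     kl4 = (pi_theta / pi_old) * d
# kl1-kl3 follow John Schulman's "Approximating KL Divergence" note
# (http://joschu.net/blog/kl-approx.html); kl4 additionally importance-weights by the ratio
# to the old policy used in the clipped objective.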
# GRPO change: add the KL penalty directly to the loss instead of folding it into the reward
if args.dr_grpo: | |
loss = masked_sum(pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis, constant_normalizer=args.response_length) | |
else: | |
loss = masked_mean(pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis) | |
loss = loss / accumulation_steps | |
self.model.backward(loss) | |
if (local_step + 1) % accumulation_steps == 0: | |
self.model.step() | |
local_step += 1 | |
with torch.no_grad(): | |
# NOTE: in the packed implementation, the KL statistics are averaged over response tokens
kl1_stats[i] = masked_mean(kl1, mb_response_masks_bool, args.masked_mean_axis).float() | |
kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, args.masked_mean_axis).float() | |
kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, args.masked_mean_axis).float() | |
kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, args.masked_mean_axis).float() | |
pg_clipfrac_stats[i] = masked_mean( | |
(pg_losses2 > pg_losses).float(), mb_response_masks_bool, args.masked_mean_axis | |
) | |
pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool, args.masked_mean_axis) | |
loss_stats[i] = loss | |
ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, args.masked_mean_axis) | |
with torch.no_grad(): | |
self.local_metrics.add("objective/kl_avg", kl1_stats.mean()) | |
self.local_metrics.add("objective/kl2_avg", kl2_stats.mean()) | |
self.local_metrics.add("objective/kl3_avg", kl3_stats.mean()) | |
self.local_metrics.add("objective/kl4_avg", kl4_stats.mean()) | |
self.local_metrics.add("loss/policy_avg", pg_loss_stats.mean()) | |
self.local_metrics.add("loss/policy_avg", loss_stats.mean()) | |
self.local_metrics.add("policy/clipfrac_avg", pg_clipfrac_stats.mean()) | |
self.local_metrics.add("val/ratio", ratio_stats.mean()) | |
self.local_metrics.add("val/ratio_var", ratio_stats.var()) | |
self.local_metrics.add("lr", self.scheduler.get_last_lr()[0]) | |
return self.local_metrics.get_metrics_list() | |
def save_model(self, output_dir: str) -> None: | |
model_to_save = self.model | |
if self.rank == 0: | |
os.makedirs(output_dir, exist_ok=True) | |
# save model weights for ZeRO2/3 | |
if hasattr(model_to_save, "module"): | |
model_to_save = model_to_save.module | |
# gather parameters | |
output_state_dict = {} | |
for k, v in model_to_save.named_parameters(): | |
# only gather z3 params | |
params_to_fetch = _z3_params_to_fetch([v]) | |
with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): | |
vv = v.data.cpu() | |
if self.rank == 0: | |
output_state_dict[k] = vv | |
if self.rank == 0: | |
state_dict = model_to_save.state_dict() | |
# copy named_buffers with `persistent=True` | |
for k, v in model_to_save.named_buffers(): | |
if k not in state_dict: | |
continue | |
vv = v.data.cpu() | |
output_state_dict[k] = vv | |
state_dict_keys = set(state_dict.keys()) | |
output_state_dict_keys = set(output_state_dict.keys()) | |
# corner case for tie_word_embeddings, such as Qwen2-0.5B | |
if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: | |
state_dict_keys.remove("lm_head.weight") | |
assert state_dict_keys.issubset( | |
output_state_dict_keys | |
), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" | |
# only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 | |
if isinstance(model_to_save, PeftModel): | |
model_to_save.save_pretrained(output_dir) | |
if self.args.deepspeed_stage == 3:
torch.save( | |
get_peft_model_state_dict(model_to_save, output_state_dict), | |
os.path.join(output_dir, "adapter_model.bin"), | |
) | |
else: | |
model_to_save.save_pretrained(output_dir, state_dict=output_state_dict) | |
# save tokenizer | |
self.tokenizer.save_pretrained(output_dir) | |
# we need this because we don't know which node rank 0 is on
def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url, training_step): | |
args = self.args | |
if self.rank == 0: | |
future = ( | |
ray.remote(launch_ai2_evals_on_weka) | |
.options(num_cpus=1) | |
.remote( | |
step_dir, | |
leaderboard_name, | |
args.oe_eval_max_length, | |
wandb_url, | |
training_step, | |
args.oe_eval_tasks, | |
args.stop_strings, | |
args.gs_bucket_path, | |
args.eval_priority, | |
) | |
) | |
else: | |
future = None | |
return future | |
class ModelGroup: | |
def __init__( | |
self, | |
pg: PlacementGroup, | |
ray_process_cls: RayProcess, | |
num_gpus_per_node: List[int], | |
single_gpu_mode: bool, | |
): | |
self.pg = pg | |
self.ray_process_cls = ray_process_cls | |
self.num_gpus_per_node = num_gpus_per_node | |
self.num_gpus_per_actor = 0.48 if single_gpu_mode else 1 | |
self.num_cpus_per_actor = 4 | |
self.models = [] | |
world_size = sum(self.num_gpus_per_node) | |
master_policy = ray_process_cls.options( | |
num_cpus=self.num_cpus_per_actor, | |
num_gpus=self.num_gpus_per_actor, | |
scheduling_strategy=PlacementGroupSchedulingStrategy( | |
placement_group=self.pg, placement_group_bundle_index=0 | |
), | |
).remote(world_size, 0, 0, None, None) | |
self.models.append(master_policy) | |
master_addr, master_port = ray.get(master_policy.get_master_addr_port.remote()) | |
def get_bundle_index(rank, num_gpus_per_node): | |
"""given a rank and a list of num_gpus_per_node, return the index of the bundle that the rank belongs to""" | |
bundle_idx = 0 | |
while rank >= num_gpus_per_node[bundle_idx]: | |
rank -= num_gpus_per_node[bundle_idx] | |
bundle_idx += 1 | |
return bundle_idx | |
assert get_bundle_index(0, [7, 8, 4]) == 0 | |
assert get_bundle_index(1, [7, 8, 4]) == 0 | |
assert get_bundle_index(7, [7, 8, 4]) == 1 | |
assert get_bundle_index(8, [7, 8, 4]) == 1 | |
assert get_bundle_index(9, [7, 8, 4]) == 1 | |
assert get_bundle_index(16, [7, 8, 4]) == 2 | |
# Setup worker models | |
for rank in range(1, world_size): | |
print(f"{rank=}, {world_size=}, {rank=}, {master_addr=}, {master_port=}") | |
scheduling_strategy = PlacementGroupSchedulingStrategy( | |
placement_group=self.pg, | |
placement_group_bundle_index=get_bundle_index(rank, self.num_gpus_per_node), | |
) | |
worker_policy = ray_process_cls.options( | |
num_cpus=self.num_cpus_per_actor, | |
num_gpus=self.num_gpus_per_actor, | |
scheduling_strategy=scheduling_strategy, | |
).remote(world_size, rank, 0, master_addr, master_port) | |
self.models.append(worker_policy) | |
def vllm_generate_thread( | |
vllm_engines: List[ray.actor.ActorHandle], | |
generation_config: SamplingParams, | |
eval_generation_config: SamplingParams, | |
inference_results_Q: Queue, | |
param_prompt_Q: Queue, | |
num_training_steps: int, | |
eval_prompt_token_ids: Optional[List[List[int]]],
evaluation_inference_results_Q: Queue, | |
eval_freq: int, | |
resume_training_step: int = 1, | |
): | |
def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams): | |
# Split queries between engines | |
queries_per_engine = (len(prompts) + len(vllm_engines) - 1) // len(vllm_engines) | |
split_queries = [prompts[i : i + queries_per_engine] for i in range(0, len(prompts), queries_per_engine)] | |
# Generate responses in parallel across engines | |
futures = [ | |
vllm_engine.generate.remote(sampling_params=sampling_params, prompt_token_ids=queries, use_tqdm=False) | |
for vllm_engine, queries in zip(vllm_engines, split_queries) | |
] | |
# Gather all responses | |
all_outputs = ray.get(futures) | |
response_ids = [] | |
finish_reasons = [] # either "stop" or "length" | |
for outputs in all_outputs: | |
response_ids.extend([list(out.token_ids) for output in outputs for out in output.outputs]) | |
finish_reasons.extend([out.finish_reason for output in outputs for out in output.outputs]) | |
return response_ids, finish_reasons | |
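# Note: generate_with_engines splits the prompt list into chunks of
# ceil(len(prompts) / num_engines), e.g. 16 prompts over 3 engines become chunks of 6, 6, and 4,
# each sent to one engine; the per-engine outputs are then flattened back into a single
# response list in order.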
for training_step in range(resume_training_step, num_training_steps + 1): | |
items = param_prompt_Q.get() | |
if items is None: | |
break | |
_, g_queries_list = items | |
with Timer("🔥 Generation time"): | |
response_ids, finish_reasons = generate_with_engines(g_queries_list, generation_config) | |
inference_results_Q.put((response_ids, finish_reasons)) | |
# Evaluate the model | |
if eval_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0: | |
response_ids, finish_reasons = generate_with_engines(eval_prompt_token_ids, eval_generation_config) | |
evaluation_inference_results_Q.put((response_ids, finish_reasons)) | |
def data_preparation_thread( | |
reward_fn: Callable, | |
inference_results_Q: Queue, | |
packed_sequences_Q: Queue, | |
queries_prompt_Q: Queue, | |
args: Args, | |
tokenizer: PreTrainedTokenizer, | |
num_training_steps: int, | |
): | |
for training_step in range(1, num_training_steps + 1): | |
# Get next batch of prompts and responses | |
items = queries_prompt_Q.get() | |
queries, ground_truths, datasets = items | |
# ------------------------------------------------------------------------------------------------ | |
# Pack sequences | |
if args.num_samples_per_prompt_rollout > 1: | |
queries = [item for item in queries for _ in range(args.num_samples_per_prompt_rollout)] | |
ground_truths = [item for item in ground_truths for _ in range(args.num_samples_per_prompt_rollout)] | |
datasets = [item for item in datasets for _ in range(args.num_samples_per_prompt_rollout)] | |
with Timer("🚀 [Data Preparation Thread] Getting response ids"): | |
responses, finish_reasons = inference_results_Q.get() | |
for i in range(len(finish_reasons)): | |
if finish_reasons[i] == "stop" and responses[i][-1] != tokenizer.eos_token_id: | |
responses[i].append(tokenizer.eos_token_id) | |
with Timer("📦 [Data Preparation Thread] Packing sequences"): | |
packed_sequences = pack_sequences( | |
queries=queries, | |
responses=responses, | |
pack_length=args.pack_length, | |
pad_token_id=tokenizer.pad_token_id, | |
) | |
num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses) | |
with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True): | |
decoded_responses = tokenizer.batch_decode(responses, skip_special_tokens=True) | |
stop_rate = sum(int(finish_reason == "stop") for finish_reason in finish_reasons) / len(finish_reasons) | |
with Timer("💰 [Data Preparation Thread] Calculating rewards"): | |
scores, reward_metrics = reward_fn(responses, decoded_responses, ground_truths, datasets, finish_reasons) | |
with Timer("🎆 [Data Preparation Thread] Calculating advantages"): | |
# Calculate advantages | |
scores = np.array(scores) | |
print(f"[Data Preparation Thread] {len(scores)=}") | |
scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout) | |
global_mean_grouped_rewards = scores_per_prompt.mean(axis=-1) | |
global_mean_grouped_rewards = np.repeat( | |
global_mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0 | |
) | |
if args.dr_grpo: | |
global_advantages = (scores - global_mean_grouped_rewards) | |
else: | |
global_std_grouped_rewards = scores_per_prompt.std(axis=-1) | |
global_std_grouped_rewards = np.repeat( | |
global_std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0 | |
) | |
global_advantages = (scores - global_mean_grouped_rewards) / (global_std_grouped_rewards + 1e-8) | |
# Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value | |
# and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed) | |
lookup_advantages = np.zeros(len(global_advantages) + 1, dtype=np.float32) | |
lookup_advantages[1:] = global_advantages | |
packed_advantages = [ | |
torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) | |
for packed_mask in packed_sequences.response_masks | |
] | |
packed_sequences.advantages = packed_advantages | |
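# Illustrative example of the lookup above, assuming response_masks store 1-indexed response ids
# (0 marks prompt/padding tokens), as the comment says:
#     global_advantages = [0.5, -0.2]        ->  lookup_advantages = [0.0, 0.5, -0.2]
#     packed_mask       = [0, 0, 1, 1, 2, 2] ->  token advantages  = [0.0, 0.0, 0.5, 0.5, -0.2, -0.2]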
with Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"): | |
B = ( | |
len(packed_sequences.query_responses) // args.world_size | |
) # essentially doing `drop_last=True`, which is fine. | |
collated_data = [] | |
for i in range(args.world_size): | |
per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)] | |
per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)] | |
per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)] | |
per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)] | |
per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)] | |
# Shuffle the batch and collate the data | |
b_inds = np.random.permutation(len(per_device_packed_query_responses)) | |
collated_query_responses = [] | |
collated_attention_masks = [] | |
collated_position_ids = [] | |
collated_response_masks = [] | |
collated_advantages = [] | |
for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size): | |
micro_range = b_inds[j : j + args.per_device_train_batch_size] | |
collated_query_responses.append( | |
collate_fn( | |
[per_device_packed_query_responses[idx] for idx in micro_range], tokenizer.pad_token_id | |
) | |
) | |
collated_attention_masks.append( | |
collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0) | |
) | |
collated_position_ids.append( | |
collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0) | |
) | |
collated_response_masks.append( | |
collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0) | |
) | |
collated_advantages.append( | |
collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0) | |
) | |
collated_data.append( | |
{ | |
"collated_query_responses": collated_query_responses, | |
"collated_attention_masks": collated_attention_masks, | |
"collated_position_ids": collated_position_ids, | |
"collated_advantages": collated_advantages, | |
"collated_response_masks": collated_response_masks, | |
} | |
) | |
# Create a result package with metrics and data | |
sequence_lengths = np.array([len(response) for response in responses]) | |
metrics = { | |
"scores": np.array(scores).mean(), | |
"val/sequence_lengths": sequence_lengths.mean(), | |
"val/sequence_lengths_min": sequence_lengths.min(), | |
"val/sequence_lengths_max": sequence_lengths.max(), | |
"val/stop_rate": stop_rate, | |
**reward_metrics, | |
} | |
if args.save_traces: | |
traces = { | |
"scores": scores.tolist(), | |
"finish_reasons": finish_reasons, | |
"responses": responses, | |
"queries": queries, | |
"ground_truths": ground_truths, | |
"datasets": datasets, | |
"training_step": training_step, | |
**reward_metrics, | |
} | |
os.makedirs(args.output_dir, exist_ok=True) | |
with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f: | |
json.dump(traces, f) | |
f.write("\n") | |
# Put the packed sequences and metrics into the output queue | |
packed_sequences_Q.put( | |
{ | |
"packed_sequences": packed_sequences, # for debugging purposes | |
"collated_data": collated_data, | |
"metrics": metrics, | |
"responses_count": len(responses), | |
"num_new_tokens": num_new_tokens, | |
"B": B, | |
} | |
) | |
def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn: Callable): | |
# ------------------------------------------------------------ | |
# Setup tokenizer | |
tc.tokenizer_revision = model_config.model_revision if tc.tokenizer_revision is None else tc.tokenizer_revision | |
tc.tokenizer_name_or_path = ( | |
model_config.model_name_or_path if tc.tokenizer_name_or_path is None else tc.tokenizer_name_or_path | |
) | |
if ( | |
tc.tokenizer_revision != model_config.model_revision | |
and tc.tokenizer_name_or_path != model_config.model_name_or_path | |
): | |
# Warn user if tokenizer and model use different revisions; this is an unusual | |
# use case. | |
warning = f"""Requested tokenizer revision `{tc.tokenizer_revision=}` is different | |
from the model revision `{model_config.model_revision=}` or the tokenizer name `{tc.tokenizer_name_or_path=}` | |
is different from the model name `{model_config.model_name_or_path=}`.""" | |
print(warning) | |
tokenizer = tc.tokenizer | |
# ------------------------------------------------------------ | |
# Set up runtime variables | |
args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" | |
args.output_dir = os.path.join(args.output_dir, args.run_name) | |
args.dataset_local_cache_dir = os.path.abspath(args.dataset_local_cache_dir) | |
if is_beaker_job(): | |
args.dataset_local_cache_dir = "/weka/oe-adapt-default/allennlp/deletable_open_instruct_dataset_cache" | |
args.world_size = sum(args.num_learners_per_node) | |
args.num_training_steps = args.total_episodes // ( | |
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout | |
) | |
args.eval_freq = max(1, args.num_training_steps // args.num_evals) | |
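# Worked example with the defaults above: total_episodes=100,000 with 16 unique prompts x 4
# samples = 64 episodes per step gives num_training_steps = 100000 // 64 = 1562, and with
# num_evals=10 we get eval_freq = max(1, 1562 // 10) = 156.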
args.try_launch_beaker_eval_jobs_on_weka = args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job() | |
if args.push_to_hub: | |
if args.hf_repo_id is None: # auto-generate one | |
args.hf_repo_id = "open_instruct_dev" | |
if args.hf_entity is None: # first try to use AI2 entity | |
args.hf_entity = maybe_use_ai2_hf_entity() | |
if args.hf_entity is None: # then try to use the user's entity | |
args.hf_entity = HfApi().whoami()["name"] | |
args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" | |
if args.hf_repo_revision is None: # auto-generate one | |
args.hf_repo_revision = args.run_name | |
args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" | |
if args.with_tracking: | |
if args.wandb_entity is None: | |
args.wandb_entity = maybe_use_ai2_wandb_entity() | |
# ------------------------------------------------------------ | |
# Setup experiment tracking and seeds | |
all_configs = {} | |
beaker_config = None | |
if is_beaker_job(): | |
beaker_config = maybe_get_beaker_config() | |
all_configs.update(vars(beaker_config)) | |
all_configs.update(**asdict(args), **asdict(tc), **asdict(model_config)) | |
if args.with_tracking: | |
import wandb | |
wandb.init( | |
project=args.wandb_project_name, | |
entity=args.wandb_entity, | |
sync_tensorboard=True, | |
config=all_configs, | |
name=args.run_name, | |
save_code=True, | |
tags=[args.exp_name] + get_wandb_tags(), | |
) | |
writer = SummaryWriter(f"runs/{args.run_name}") | |
writer.add_text( | |
"hyperparameters", | |
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), | |
) | |
# ------------------------------------------------------------ | |
# Set up datasets | |
transform_fn_args = [ | |
{}, | |
{ | |
"max_token_length": args.max_token_length, | |
"max_prompt_token_length": args.max_prompt_token_length, | |
}, | |
] | |
train_dataset = get_cached_dataset_tulu( | |
dataset_mixer_list=args.dataset_mixer_list, | |
dataset_mixer_list_splits=args.dataset_mixer_list_splits, | |
tc=tc, | |
dataset_transform_fn=args.dataset_transform_fn, | |
transform_fn_args=transform_fn_args, | |
dataset_cache_mode=args.dataset_cache_mode, | |
dataset_config_hash=args.dataset_config_hash, | |
hf_entity=args.hf_entity, | |
dataset_local_cache_dir=args.dataset_local_cache_dir, | |
dataset_skip_cache=args.dataset_skip_cache, | |
) | |
train_dataset = train_dataset.shuffle(seed=args.seed) | |
eval_dataset = None | |
if len(args.dataset_mixer_eval_list) > 0: | |
eval_dataset = get_cached_dataset_tulu( | |
args.dataset_mixer_eval_list, | |
args.dataset_mixer_eval_list_splits, | |
tc, | |
args.dataset_transform_fn, | |
transform_fn_args, | |
hf_entity=args.hf_entity, | |
dataset_cache_mode=args.dataset_cache_mode, | |
dataset_config_hash=args.dataset_config_eval_hash, | |
dataset_local_cache_dir=args.dataset_local_cache_dir, | |
dataset_skip_cache=args.dataset_skip_cache, | |
) | |
eval_dataset = eval_dataset.shuffle(seed=args.seed) | |
if args.cache_dataset_only: | |
return | |
# ------------------------------------------------------------ | |
# Runtime setups and quick logging | |
pprint([args, model_config]) | |
visualize_token(train_dataset[0][INPUT_IDS_PROMPT_KEY], tokenizer) | |
# ------------------------------------------------------------ | |
# Create the model and optimizer | |
pg = None | |
bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node] | |
pg = placement_group(bundles, strategy="STRICT_SPREAD") | |
ray.get(pg.ready()) | |
inits = [] | |
policy_group = ModelGroup( | |
pg, | |
PolicyTrainerRayProcess, | |
args.num_learners_per_node, | |
args.single_gpu_mode, | |
) | |
wandb_url = wandb.run.get_url() if args.with_tracking else None | |
inits.extend( | |
model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) | |
for model in policy_group.models | |
) | |
max_len = args.max_prompt_token_length + args.response_length | |
vllm_engines = create_vllm_engines( | |
args.vllm_num_engines, | |
args.vllm_tensor_parallel_size, | |
args.vllm_enforce_eager, | |
model_config.model_name_or_path, | |
model_config.model_revision, | |
args.seed, | |
args.vllm_enable_prefix_caching, | |
max_len, | |
args.vllm_gpu_memory_utilization, | |
args.single_gpu_mode, | |
pg=pg if args.single_gpu_mode else None, | |
) | |
ray.get(inits) | |
print("======== ✅ all models and vLLM engines initialized =========") | |
ray.get([m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models]) | |
print("======== ✅ model update group setup successfully =========") | |
# Setup training | |
generation_config = SamplingParams( | |
temperature=args.temperature, | |
top_p=1.0, | |
max_tokens=args.response_length, | |
include_stop_str_in_output=True, | |
n=args.num_samples_per_prompt_rollout, | |
stop=args.stop_strings, | |
) | |
eval_generation_config = SamplingParams( | |
temperature=0.0, | |
top_p=1.0, | |
max_tokens=args.response_length, | |
include_stop_str_in_output=True, | |
n=1, # since we are doing greedy sampling, don't need to generate more | |
stop=args.stop_strings, | |
) | |
train_dataset_idxs = np.arange(len(train_dataset)) | |
iter_dataloader = ShufflingIterator(train_dataset_idxs, args.num_unique_prompts_rollout, seed=args.seed) | |
inference_results_Q = Queue(maxsize=1) | |
param_prompt_Q = Queue(maxsize=1) | |
evaluation_inference_results_Q = Queue(maxsize=1) | |
packed_sequences_Q = Queue(maxsize=1) | |
queries_prompt_Q = Queue(maxsize=1) | |
num_eval_samples = 32 | |
eval_prompt_token_ids = None | |
eval_ground_truths = None | |
if eval_dataset is not None: | |
eval_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY] | |
eval_ground_truths = eval_dataset[:num_eval_samples][GROUND_TRUTHS_KEY] | |
resume_training_step = 1 | |
thread = threading.Thread( | |
target=vllm_generate_thread, | |
args=( | |
vllm_engines, | |
generation_config, | |
eval_generation_config, | |
inference_results_Q, | |
param_prompt_Q, | |
args.num_training_steps, | |
eval_prompt_token_ids, | |
evaluation_inference_results_Q, | |
args.eval_freq, | |
resume_training_step, | |
), | |
) | |
thread.start() | |
print("======== ✅ vllm generate thread starts =========") | |
packing_thread = threading.Thread( | |
target=data_preparation_thread, | |
args=( | |
reward_fn, | |
inference_results_Q, | |
packed_sequences_Q, | |
queries_prompt_Q, | |
args, | |
tokenizer, | |
args.num_training_steps, | |
), | |
) | |
packing_thread.start() | |
print("======== ✅ data preparation thread starts =========") | |
# Send initial data to both threads | |
data_next = train_dataset[next(iter_dataloader)] | |
queries_next = data_next[INPUT_IDS_PROMPT_KEY] | |
ground_truths_next = data_next[GROUND_TRUTHS_KEY] | |
datasets_next = data_next[DATASET_SOURCE_KEY] | |
queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next)) | |
param_prompt_Q.put((None, queries_next)) | |
episode = 0 | |
num_total_tokens = 0 | |
start_time = time.time() | |
eval_futures = [] | |
try: | |
for training_step in range(resume_training_step, args.num_training_steps + 1): | |
print("-" * 100) | |
episode += ( | |
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout | |
) # each sample is an episode | |
# ------------------------------------------------------------------------------------------------ | |
# Optionally evaluate the model | |
try: | |
evaluation_responses, _ = evaluation_inference_results_Q.get(timeout=0.01) | |
print("[Main Thread] 📊 Evaluation responses received") | |
table = {} | |
table["prompt"] = tokenizer.batch_decode(eval_prompt_token_ids) | |
table["response"] = tokenizer.batch_decode(evaluation_responses) | |
table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]] | |
table["ground_truth"] = eval_ground_truths | |
df = pd.DataFrame(table) | |
if args.with_tracking: | |
wandb.log({"sample_completions": wandb.Table(dataframe=df)}) | |
else: | |
print_rich_table(df.iloc[:1]) | |
del table | |
except Empty: | |
print("[Main Thread] 🙈 Evaluation responses not received") | |
# ------------------------------------------------------------------------------------------------ | |
# Sync weights and send the next batch of prompts to vLLM | |
if args.async_mode: | |
if training_step != 1: | |
data_next = train_dataset[next(iter_dataloader)] | |
queries_next = data_next[INPUT_IDS_PROMPT_KEY] | |
ground_truths_next = data_next[GROUND_TRUTHS_KEY] | |
datasets_next = data_next[DATASET_SOURCE_KEY] | |
with Timer("[Main Thread] 🔄 Loading weights using shared memory"): | |
ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models]) | |
queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next)) | |
param_prompt_Q.put((None, queries_next)) | |
else: | |
if training_step != 1: | |
                    # NOTE: unlike async mode, in sync mode the two `put` calls below are also nested | |
                    # inside `if training_step != 1`, so the freshly sampled prompts are generated and | |
                    # trained on within the same step (step 1 uses the prompts primed before the loop). | |
data_next = train_dataset[next(iter_dataloader)] | |
queries_next = data_next[INPUT_IDS_PROMPT_KEY] | |
ground_truths_next = data_next[GROUND_TRUTHS_KEY] | |
datasets_next = data_next[DATASET_SOURCE_KEY] | |
with Timer("🔄 Loading weights using shared memory"): | |
ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models]) | |
queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next)) | |
param_prompt_Q.put((None, queries_next)) | |
# ------------------------------------------------------------------------------------------------ | |
# Get the packed sequences with advantages from the packing thread | |
with Timer("[Main Thread] 📦 Getting packed sequences from thread"): | |
packed_data = packed_sequences_Q.get() | |
packed_sequences = packed_data["packed_sequences"] | |
data_thread_metrics = packed_data["metrics"] | |
B = packed_data["B"] | |
collated_data = packed_data["collated_data"] | |
num_total_tokens += packed_data["num_new_tokens"] | |
# Log info about the packed sequences | |
print( | |
f"Number of training examples per device: {B=}, packed sequence fraction of original sequences: {len(packed_sequences.query_responses) / packed_data['responses_count']}" | |
) | |
if B == 0: | |
print("[Main Thread] 🤡 After packing, there is not enough data to train") | |
continue | |
# ------------------------------------------------------------------------------------------------ | |
# Train the model | |
with Timer("[Main Thread] 🗡️ Training"): | |
metrics_list: List[dict[str, float]] = ray.get( | |
[ | |
policy_group.models[i].train.remote( | |
**collated_data[i], | |
pad_token_id=tokenizer.pad_token_id, | |
num_mini_batches=args.num_mini_batches, | |
) | |
for i in range(args.world_size) | |
] | |
) | |
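            # Each rank i trains on its own shard collated_data[i]; ray.get blocks until every rank | |
            # returns its metrics dict, and those dicts are averaged across ranks below. | |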
average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]} | |
metrics = { | |
"episode": episode, | |
"training_step": training_step, | |
"val/num_total_tokens": num_total_tokens, | |
"epoch": episode / args.num_samples_per_prompt_rollout / len(train_dataset), | |
"tokens_per_second": num_total_tokens / (time.time() - start_time), | |
**data_thread_metrics, | |
**average_metrics, | |
} | |
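            # Worked example (hypothetical numbers): with num_unique_prompts_rollout=64 and | |
            # num_samples_per_prompt_rollout=4, each step adds 64 * 4 = 256 episodes, and | |
            # epoch = episode / 4 / len(train_dataset) counts passes over the unique prompts. | |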
print_rich_single_line_metrics(metrics) | |
for key, value in metrics.items(): | |
writer.add_scalar(key, value, episode) | |
if args.save_freq > 0 and training_step % args.save_freq == 0: | |
with Timer("[Main Thread] 🗡️ Saving model"): | |
checkpoint_dir = f"{args.output_dir}_checkpoints" | |
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") | |
print(f"Saving model at step {training_step} to {step_dir}") | |
ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)]) | |
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job(): | |
leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" | |
eval_futures.extend( | |
[ | |
policy_group.models[i].launch_ai2_evals_on_weka_wrapper.remote( | |
step_dir, leaderboard_name, wandb_url, training_step | |
) | |
for i in range(args.world_size) | |
] | |
) | |
print(f"Saving final model at step {training_step} to {args.output_dir}") | |
with Timer("[Main Thread] 🗡️ Saving model"): | |
ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)]) | |
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job(): | |
leaderboard_name = args.hf_repo_revision | |
eval_futures.extend( | |
[ | |
policy_group.models[i].launch_ai2_evals_on_weka_wrapper.remote( | |
args.output_dir, leaderboard_name, wandb_url, training_step | |
) | |
for i in range(args.world_size) | |
] | |
) | |
except Exception as e: | |
print(f"Training error occurred: {str(e)}") | |
print(traceback.format_exc()) | |
ray.shutdown() | |
        # The generation and packing threads are non-daemon and may still be blocking, | |
        # so force-exit the whole process instead of re-raising. | |
        os._exit(1) | |
# Clean up threads | |
thread.join() | |
print("======== ✅ vllm generate thread ends =========") | |
packing_thread.join() | |
print("======== ✅ data preparation thread ends =========") | |
ray.shutdown() | |
    # Ai2 logic: /output stores the artifacts of the job, so we copy the final model | |
    # to `/output` at the end. | |
if ( | |
args.try_auto_save_to_beaker | |
and is_beaker_job() | |
and len(beaker_config.beaker_dataset_id_urls) > 0 | |
and args.output_dir.rstrip("/") != "/output" | |
): | |
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) | |
print("finished training") | |
accelerator = Namespace() | |
accelerator.is_main_process = True # hack | |
if args.push_to_hub: | |
print("Pushing model to hub") | |
push_folder_to_hub( | |
accelerator, | |
args.output_dir, | |
args.hf_repo_id, | |
args.hf_repo_revision, | |
) | |
if __name__ == "__main__": | |
parser = ArgumentParserPlus((Args, TokenizerConfig, ModelConfig)) | |
args, tokenizer_config, model_config = parser.parse_args_into_dataclasses() | |
assert isinstance(args, Args) | |
assert isinstance(tokenizer_config, TokenizerConfig) | |
assert isinstance(model_config, ModelConfig) | |
def reward_fn( | |
responses: List[torch.Tensor], | |
decoded_responses: List[str], | |
ground_truths: List[str], | |
datasets: List[str], | |
finish_reasons: List[str], | |
    ) -> tuple[List[float], dict]: | |
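        """Combine the enabled reward signals for one batch of rollouts and return | |
        a `(scores, metrics)` tuple: one float score per response plus aggregate | |
        statistics for logging.""" | |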
scores = [0] * len(decoded_responses) | |
metrics = {} | |
if args.apply_r1_style_format_reward: | |
with Timer("[Data Preparation Thread] Calculating rewards -- 🧮 Calculating format reward"): | |
format_scores = soft_format_reward_func(decoded_responses, args.r1_style_format_reward) | |
if len(format_scores) != len(scores): | |
raise ValueError(f"{len(format_scores)=} != {len(scores)=}") | |
for i in range(len(format_scores)): | |
scores[i] = format_scores[i] + scores[i] | |
metrics["val/format_scores"] = np.array(format_scores).mean() | |
if args.apply_verifiable_reward: | |
with Timer("[Data Preparation Thread] Calculating rewards -- 🏆 Applying verifiable reward"): | |
verifiable_rewards, per_func_rewards = apply_verifiable_reward( | |
responses, | |
decoded_responses, | |
ground_truths, | |
datasets, | |
reward_mult=args.verification_reward, | |
) | |
if len(verifiable_rewards) != len(scores): | |
raise ValueError(f"{len(verifiable_rewards)=} != {len(scores)=}") | |
for i in range(len(verifiable_rewards)): | |
scores[i] = verifiable_rewards[i] + scores[i] | |
np_verifiable_rewards = np.array(verifiable_rewards) | |
metrics["objective/verifiable_reward"] = np_verifiable_rewards.mean() | |
metrics["objective/verifiable_correct_rate"] = (np_verifiable_rewards > 0.0).mean() | |
# reshuffle around per_func rewards | |
per_func_lists = defaultdict(list) | |
for reward_dict in per_func_rewards: | |
for key, value in reward_dict.items(): | |
per_func_lists[key].append(value) | |
# log per function rewards | |
for key, value in per_func_lists.items(): | |
np_value = np.array(value) | |
metrics[f"objective/{key}_reward"] = np_value.mean() | |
metrics[f"objective/{key}_correct_rate"] = (np_value > 0.0).mean() | |
        # note: the non-stop penalty replaces (rather than adds to) the reward accumulated so far; | |
        # rewards applied after this block (e.g. the arithmetic reward) still add on top of it. | |
if args.non_stop_penalty: | |
with Timer("[Data Preparation Thread] Calculating rewards -- 🦖 Applying non stop penalty"): | |
assert len(finish_reasons) == len(scores) | |
for i in range(len(finish_reasons)): | |
if finish_reasons[i] != "stop": | |
scores[i] = args.non_stop_penalty_value | |
# @nouha: handle arithmetic reward | |
if args.apply_arithmetic_reward: | |
with Timer("[Data Preparation Thread] Calculating rewards -- 🧮 Calculating arithmetic reward"): | |
arithmetic_rewards = [] | |
for i in range(len(decoded_responses)): | |
# extract the string between <answer> and </answer> | |
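                        # e.g. "...<answer> 1,234 </answer>" -> " 1,234 " -> 1234.0 after the normalization below | |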
decoded_response = decoded_responses[i] | |
answer_start = decoded_response.find("<answer>") + len("<answer>") | |
answer_end = decoded_response.find("</answer>") | |
# normalize the number (e.g., 1,000 -> 1000) | |
try: | |
answer = decoded_response[answer_start:answer_end] | |
answer = answer.replace(",", "").strip() | |
if float(answer) == float(ground_truths[i]): | |
arithmetic_rewards.append(args.arithmetic_reward) | |
scores[i] += args.arithmetic_reward | |
else: | |
arithmetic_rewards.append(0) | |
                    except Exception:  # noqa -- malformed or missing <answer> tags simply earn no reward | |
                        arithmetic_rewards.append(0) | |
metrics["objective/arithmetic_score"] = np.array(arithmetic_rewards).mean() | |
metrics["objective/arithmetic_correct_rate"] = (np.array(arithmetic_rewards) > 0.0).mean() | |
return scores, metrics | |
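    # Illustrative usage sketch (not executed): the data preparation thread is expected to call | |
    #   scores, reward_metrics = reward_fn(responses, decoded_responses, ground_truths, datasets, finish_reasons) | |
    # once per batch of rollouts; the exact call site lives in data_preparation_thread. | |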
main(args, tokenizer_config, model_config, reward_fn) |