Skip to content

Instantly share code, notes, and snippets.

@kykim0
Created March 27, 2024 04:05
Show Gist options
  • Select an option

  • Save kykim0/bf4afb6ace402ecb0abebc9d08ccf216 to your computer and use it in GitHub Desktop.

Select an option

Save kykim0/bf4afb6ace402ecb0abebc9d08ccf216 to your computer and use it in GitHub Desktop.
CAPO
class RewardTrainer(Trainer):
    r"""
    The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
    `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use
    an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset
    of paired examples, where each example is a tuple of two sequences. The reward model should be trained to
    predict which example in the pair is more relevant to the task at hand.

    The reward trainer expects a very specific format for the dataset. The dataset should contain at least
    4 entries if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should
    be named

    - `input_ids_chosen`
    - `attention_mask_chosen`
    - `input_ids_rejected`
    - `attention_mask_rejected`

    Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to
    modulate the loss of the reward model as outlined in
    https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/.
    If you don't pass a margin, no margin will be used.

    Optionally, batches may also carry a `chosen_probs` entry: the target probability that the chosen
    sequence is preferred. When present (and no `margin` is given) the loss becomes a soft-label
    cross-entropy between that target and `sigmoid(reward_chosen - reward_rejected)`; when absent the
    standard pairwise loss is used.
    """

    _tag_names = ["trl", "reward-trainer"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[RewardConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
    ):
        """
        Initialize RewardTrainer.

        Args:
            model (`transformers.PreTrainedModel`):
                The model to train, preferably an `AutoModelForSequenceClassification`.
            args (`RewardConfig`):
                The arguments to use for training.
            data_collator (`transformers.DataCollator`):
                The data collator to use for training. If None is specified, the default data collator
                (`RewardDataCollatorWithPadding`) will be used which will pad the sequences to the maximum
                length of the sequences in the batch, given a dataset of paired sequences.
            train_dataset (`datasets.Dataset`):
                The dataset to use for training.
            eval_dataset (`datasets.Dataset`):
                The dataset to use for evaluation.
            tokenizer (`transformers.PreTrainedTokenizerBase`):
                The tokenizer to use for training. This argument is required if you want to use the default
                data collator.
            model_init (`Callable[[], transformers.PreTrainedModel]`):
                The model initializer to use for training. If None is specified, the default model
                initializer will be used.
            compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to `compute_accuracy`):
                The metrics to use for evaluation. If no metrics are specified, the default metric
                (`compute_accuracy`) will be used.
            callbacks (`List[transformers.TrainerCallback]`):
                The callbacks to use for training.
            optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
                The optimizer and scheduler to use for training.
            preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
                The function to use to preprocess the logits before computing the metrics.
            max_length (`int`, defaults to `None`):
                Deprecated: set `max_length` on the `RewardConfig` instead. Used as the `max_length` of the
                default data collator when given.
            peft_config (`Dict`, defaults to `None`):
                The PEFT configuration to use for training. If you pass a PEFT configuration, the model will
                be wrapped in a PEFT model.

        Raises:
            ValueError: If both `max_length` and `args.max_length` are set, if `peft_config` is passed while
                PEFT is not installed, or if the default data collator is needed but no tokenizer was given.
        """
        # Exact-type comparison on purpose: `isinstance` would also match subclasses of
        # `TrainingArguments`, which must take the `else` branches below.
        # NOTE(review): `RewardConfig` is presumably such a subclass — confirm.
        uses_plain_training_args = type(args) is TrainingArguments

        if uses_plain_training_args:
            warnings.warn(
                "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
                FutureWarning,
            )
            if max_length is not None:
                warnings.warn(
                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
                    FutureWarning,
                )
        else:
            if max_length is not None and args.max_length is not None:
                raise ValueError(
                    "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
                )
            if max_length is not None and args.max_length is None:
                warnings.warn(
                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
                    FutureWarning,
                )

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    # Only forward `gradient_checkpointing_kwargs` when the installed
                    # `prepare_model_for_kbit_training` accepts that parameter.
                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                        inspect.signature(prepare_model_for_kbit_training).parameters
                    )

                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
                        )
                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                model = get_peft_model(model, peft_config)

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        if data_collator is None:
            if tokenizer is None:
                raise ValueError(
                    "max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding"
                )
            # Resolve the collator's max_length: explicit argument, then the config value, then 512.
            if uses_plain_training_args:
                if max_length is None:
                    warnings.warn(
                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                        " It will be set to `512` by default, but you should do it yourself in the future.",
                        UserWarning,
                    )
                    max_length = 512
            else:
                if max_length is None and args.max_length is None:
                    warnings.warn(
                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                        " It will be set to `512` by default, but you should do it yourself in the future.",
                        UserWarning,
                    )
                    max_length = 512
                if max_length is None and args.max_length is not None:
                    max_length = args.max_length

            data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)

            if args.remove_unused_columns:
                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
                    args.remove_unused_columns = False
                except FrozenInstanceError:
                    args = replace(args, remove_unused_columns=False)
                # warn users
                warnings.warn(
                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_reward_data_collator = True
        else:
            self.use_reward_data_collator = False

        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            tokenizer,
            model_init,
            compute_metrics,
            callbacks,
            optimizers,
            preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Compute the pairwise reward-model loss for one batch.

        Args:
            model: The reward model; called once on the chosen and once on the rejected sequences.
            inputs: Batch with `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and
                `attention_mask_rejected`; optionally `margin` or `chosen_probs`.
            return_outputs: When true, also return the two reward tensors.

        Returns:
            The scalar loss, or `(loss, {"rewards_chosen": ..., "rewards_rejected": ...})` when
            `return_outputs` is true.

        Loss selection:
            * `margin` present: `-logsigmoid(r_chosen - r_rejected - margin)`; `chosen_probs` is ignored.
            * `chosen_probs` present: `-p * logsigmoid(d) - (1 - p) * logsigmoid(-d)` with
              `d = r_chosen - r_rejected`.
            * neither: `-logsigmoid(d)`, i.e. the previous case with `p = 1`.
        """
        if not self.use_reward_data_collator:
            warnings.warn(
                "The current compute_loss is implemented for RewardDataCollatorWithPadding,"
                " if you are using a custom data collator make sure you know what you are doing or"
                " implement your own compute_loss method."
            )
        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
            return_dict=True,
        )["logits"]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
            return_dict=True,
        )["logits"]
        reward_diff = rewards_chosen - rewards_rejected

        # calculate loss, optionally modulate with margin
        if "margin" in inputs:
            # NOTE(review): the margin is subtracted as-is. If the collator emits it as a flat
            # per-example vector while the logits carry a trailing singleton dim, this broadcasts
            # across examples — confirm the collator's output shape.
            loss = -nn.functional.logsigmoid(reward_diff - inputs["margin"]).mean()
        else:
            if "chosen_probs" in inputs:
                chosen_probs = torch.as_tensor(inputs["chosen_probs"], device=reward_diff.device)
                # Pair each soft label with its own example: a flat vector multiplied with
                # logits that have a trailing singleton dim would otherwise broadcast to a
                # (batch, batch) matrix and mix labels across examples.
                if chosen_probs.numel() == reward_diff.numel():
                    chosen_probs = chosen_probs.reshape(reward_diff.shape)
                losses = (
                    -nn.functional.logsigmoid(reward_diff) * chosen_probs
                    - nn.functional.logsigmoid(-reward_diff) * (1.0 - chosen_probs)
                )
            else:
                # No soft labels in the batch: equivalent to chosen_probs == 1 everywhere.
                losses = -nn.functional.logsigmoid(reward_diff)
            loss = losses.mean()

        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected,
            }
        return loss

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Run one evaluation step without gradients.

        Args:
            model: The reward model to evaluate.
            inputs: A batch in the format expected by `compute_loss`.
            prediction_loss_only: When true, return only the loss.
            ignore_keys: Output keys to drop before building the logits; defaults to the model config's
                `keys_to_ignore_at_inference` when the model has a config, else an empty list.

        Returns:
            `(loss, None, None)` when `prediction_loss_only` is true, otherwise `(loss, logits, labels)`
            where `logits` are the softmax-normalised chosen/rejected preferences and `labels` are all
            zeros (index 0 is the chosen sequence).
        """
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return (loss, None, None)

        loss = loss.detach()
        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
        logits = nested_detach(logits)
        # Stack accepted against rejected, mean over logits
        # and softmax to get preferences between accepted and rejected to sum to 1
        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T

        labels = torch.zeros(logits.shape[0])
        labels = self._prepare_inputs(labels)

        return loss, logits, labels
# Reward modeling on preference data.
from collections import defaultdict
import logging
import math
import os
import random
import sys
from alignment import (
DataArguments,
H4ArgumentParser,
ModelArguments,
RewardArguments,
get_datasets,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
get_tokenizer,
)
import torch
import transformers
from transformers import AutoModelForSequenceClassification, set_seed, TrainerCallback
from trl import RewardConfig, RewardTrainer
logger = logging.getLogger(__name__)


def main():
    """Train a reward model on preference data and save it.

    Steps, in order: parse the model/data/training arguments, configure logging, derive a run name
    and output directory from the hyperparameters, load and tokenize the preference datasets
    (attaching a per-example `chosen_probs` soft label), load a single-logit sequence-classification
    model, train it with `RewardTrainer`, then save the model, model card and config.
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, RewardArguments))
    model_args, data_args, training_args = parser.parse()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    # The run name supplied by the user is used as the W&B project; `run_name` itself is
    # replaced further down by a name derived from the hyperparameters.
    os.environ["WANDB_PROJECT"] = training_args.run_name

    #######
    # Setup
    #######
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed for reproducibility
    set_seed(training_args.seed)

    # Effective batch size across visible GPUs, used only to label the run.
    num_gpus = torch.cuda.device_count()
    per_device_batch = training_args.per_device_train_batch_size
    grad_accum_steps = training_args.gradient_accumulation_steps
    batch_size = num_gpus * per_device_batch * grad_accum_steps
    run_name = "-".join([
        f"b{batch_size}",
        f"lr{training_args.learning_rate}",
        f"s{max(0, training_args.max_steps)}",
        f"e{training_args.num_train_epochs}",
        f"btb{training_args.bt_beta or 'inf'}",
    ])
    if "wandb" in training_args.report_to:
        training_args.tracker_kwargs = {"wandb": {"name": run_name}}
    training_args.run_name = run_name
    training_args.output_dir = os.path.join(
        training_args.output_dir,
        training_args.run_name,
    )

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, token=training_args.hub_token)
    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    column_names = list(raw_datasets["train"].features)

    #####################################
    # Load tokenizer and process datasets
    #####################################
    data_args.truncation_side = "left"  # Truncate from left to ensure we don't lose labels in final turn
    tokenizer = get_tokenizer(model_args, data_args)

    def preprocess_function(examples):
        """Tokenize a batch of chosen/rejected conversations and attach soft preference labels."""
        new_examples = {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
            "chosen_probs": [],
        }
        for (chosen_messages, rejected_messages, score_chosen, score_rejected) in zip(
            examples["chosen"], examples["rejected"], examples["score_chosen"], examples["score_rejected"]
        ):
            # Confidence-aware Bradley-Terry model: the target probability that the chosen
            # response wins is sigmoid(bt_beta * score gap). A falsy bt_beta keeps the hard
            # label 1.0.
            # TODOs: 4.4, 2.2, 1.1, 0.5.
            chosen_p = 1.0
            if training_args.bt_beta:
                score_diff = score_chosen - score_rejected
                chosen_p = 1.0 / (1.0 + math.exp(-training_args.bt_beta * score_diff))
            new_examples["chosen_probs"].append(chosen_p)

            chosen_text = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            rejected_text = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
            # Ensure consistency with how the model is to be invoked, e.g., during PPO.
            chosen_tokenized = tokenizer(chosen_text)
            rejected_tokenized = tokenizer(rejected_text)
            new_examples["input_ids_chosen"].append(chosen_tokenized["input_ids"])
            new_examples["attention_mask_chosen"].append(chosen_tokenized["attention_mask"])
            new_examples["input_ids_rejected"].append(rejected_tokenized["input_ids"])
            new_examples["attention_mask_rejected"].append(rejected_tokenized["attention_mask"])
        return new_examples

    # Preprocess and filter examples that are longer than max_length.
    max_length = training_args.max_length

    def length_filter(x):
        """Keep a pair only when both tokenized sequences fit within max_length."""
        return (len(x["input_ids_chosen"]) <= max_length and
                len(x["input_ids_rejected"]) <= max_length)

    raw_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
    )
    if max_length is not None:
        raw_datasets = raw_datasets.filter(length_filter)
    else:
        # Comparing a length against None would raise TypeError; with no limit configured
        # there is nothing to filter on.
        logger.warning("`max_length` is not set; skipping the length filter.")
    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["test"]

    ########################
    # Instantiate RM trainer
    ########################
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch_dtype,
        # The k,v cache is disabled while gradient checkpointing is on; it is restored before saving.
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        num_labels=1,
        **model_kwargs,
    )
    # Logged only once the weights have actually been loaded.
    logger.info("*** Model loaded! ***")

    # The pad token id lives on the config; only fill it in when the checkpoint did not set one.
    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Train the model.
    trainer = RewardTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )
    trainer.train(training_args.resume_from_checkpoint)

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process.
    if trainer.accelerator.is_main_process:
        kwargs = {
            "finetuned_from": model_args.model_name_or_path,
            "dataset": list(data_args.dataset_mixer.keys()),
            "dataset_tags": list(data_args.dataset_mixer.keys()),
        }
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    logger.info("*** Training complete ***")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment