Created March 27, 2024 04:05
CAPO
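Two files: a patched copy of TRL's `RewardTrainer` whose pairwise loss is weighted by a per-example soft label `chosen_probs`, and a reward-modeling script that derives those soft labels from the score gap between the chosen and rejected responses via a Bradley-Terry sigmoid controlled by `bt_beta`. The patched regions are marked with `####### kykim`.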
# Patched copy of trl/trainer/reward_trainer.py; the imports below are what
# the class needs to run standalone.
import inspect
import warnings
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from datasets import Dataset
from transformers import (
    DataCollator,
    EvalPrediction,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.trainer_pt_utils import nested_detach
from trl import RewardConfig
from trl.import_utils import is_peft_available
from trl.trainer.utils import RewardDataCollatorWithPadding, compute_accuracy

if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


class RewardTrainer(Trainer):
    r"""
    The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
    `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use
    an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset
    of paired examples, where each example is a tuple of two sequences. The reward model should be trained to
    predict which example in the pair is more relevant to the task at hand.

    The reward trainer expects a very specific format for the dataset. The dataset should contain at least
    four entries if you use the default `RewardDataCollatorWithPadding` data collator. The entries should be named
    - `input_ids_chosen`
    - `attention_mask_chosen`
    - `input_ids_rejected`
    - `attention_mask_rejected`

    Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the
    loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/.
    If you don't pass a margin, no margin will be used.
    """
    _tag_names = ["trl", "reward-trainer"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[RewardConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
    ):
        """
        Initialize RewardTrainer.

        Args:
            model (`transformers.PreTrainedModel`):
                The model to train, preferably an `AutoModelForSequenceClassification`.
            args (`RewardConfig`):
                The arguments to use for training.
            data_collator (`transformers.DataCollator`):
                The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
                which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
            train_dataset (`datasets.Dataset`):
                The dataset to use for training.
            eval_dataset (`datasets.Dataset`):
                The dataset to use for evaluation.
            tokenizer (`transformers.PreTrainedTokenizerBase`):
                The tokenizer to use for training. This argument is required if you want to use the default data collator.
            model_init (`Callable[[], transformers.PreTrainedModel]`):
                The model initializer to use for training. If None is specified, the default model initializer will be used.
            compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional*, defaults to `compute_accuracy`):
                The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
            callbacks (`List[transformers.TrainerCallback]`):
                The callbacks to use for training.
            optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
                The optimizer and scheduler to use for training.
            preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
                The function to use to preprocess the logits before computing the metrics.
            max_length (`int`, defaults to `None`):
                The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
            peft_config (`Dict`, defaults to `None`):
                The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        """
        if type(args) == TrainingArguments:
            warnings.warn(
                "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
                FutureWarning,
            )
            if max_length is not None:
                warnings.warn(
                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
                    FutureWarning,
                )
        else:
            if max_length is not None and args.max_length is not None:
                raise ValueError(
                    "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
                )
            if max_length is not None and args.max_length is None:
                warnings.warn(
                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
                    FutureWarning,
                )
        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                        inspect.signature(prepare_model_for_kbit_training).parameters
                    )
                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "Please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
                        )
                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
                model = get_peft_model(model, peft_config)

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        if data_collator is None:
            if tokenizer is None:
                raise ValueError(
                    "A tokenizer must be specified when using the default RewardDataCollatorWithPadding."
                )
            if type(args) == TrainingArguments:
                if max_length is None:
                    warnings.warn(
                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                        " It will be set to `512` by default, but you should do it yourself in the future.",
                        UserWarning,
                    )
                    max_length = 512
            else:
                if max_length is None and args.max_length is None:
                    warnings.warn(
                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                        " It will be set to `512` by default, but you should do it yourself in the future.",
                        UserWarning,
                    )
                    max_length = 512
                if max_length is None and args.max_length is not None:
                    max_length = args.max_length

            data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)

            if args.remove_unused_columns:
                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
                    args.remove_unused_columns = False
                except FrozenInstanceError:
                    args = replace(args, remove_unused_columns=False)
                # warn users
                warnings.warn(
                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig."
                    " We have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_reward_data_collator = True
        else:
            self.use_reward_data_collator = False

        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            tokenizer,
            model_init,
            compute_metrics,
            callbacks,
            optimizers,
            preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)
    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        if not self.use_reward_data_collator:
            warnings.warn(
                "The current compute_loss is implemented for RewardDataCollatorWithPadding,"
                " if you are using a custom data collator make sure you know what you are doing or"
                " implement your own compute_loss method."
            )
        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
            return_dict=True,
        )["logits"]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
            return_dict=True,
        )["logits"]
        ####### kykim
        # Soft label for each pair: the probability that the "chosen" response
        # really is preferred. Expected to broadcast against the (batch, 1)
        # reward logits.
        chosen_probs = inputs["chosen_probs"]
        #######
        # calculate loss, optionally modulate with margin
        if "margin" in inputs:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
        else:
            ####### kykim
            # Confidence-weighted Bradley-Terry loss: weight each direction of
            # the pairwise comparison by the soft label.
            losses = (
                -nn.functional.logsigmoid(rewards_chosen - rewards_rejected) * chosen_probs
                - nn.functional.logsigmoid(rewards_rejected - rewards_chosen) * (1.0 - chosen_probs)
            )
            loss = losses.mean()
            #######
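        # Worked example of the confidence-weighted loss, with illustrative
        # numbers: for chosen_probs = 0.9 and rewards_chosen - rewards_rejected = 1.0,
        #   loss = -0.9 * log(sigmoid(1.0)) - 0.1 * log(sigmoid(-1.0))
        #        =  0.9 * 0.3133 + 0.1 * 1.3133 ≈ 0.413.
        # With chosen_probs = 1.0 it reduces to the standard pairwise loss
        # -log(sigmoid(rewards_chosen - rewards_rejected)).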
        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected,
            }
        return loss

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []
        with torch.no_grad():
            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
        if prediction_loss_only:
            return (loss, None, None)
        loss = loss.detach()
        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
        logits = nested_detach(logits)
        # Stack accepted against rejected, mean over logits
        # and softmax to get preferences between accepted and rejected to sum to 1
        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
        labels = torch.zeros(logits.shape[0])
        labels = self._prepare_inputs(labels)
        return loss, logits, labels
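`compute_loss` above reads `inputs["chosen_probs"]`, but TRL's stock `RewardDataCollatorWithPadding` only batches the four tokenized fields plus an optional `margin`, so the trainer has to be paired with a collator that forwards the soft labels. The gist does not include that collator; the sketch below is a minimal assumed implementation (the class name `RewardDataCollatorWithProbs` and its structure are hypothetical, not part of the gist):

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from transformers import PreTrainedTokenizerBase


@dataclass
class RewardDataCollatorWithProbs:
    """Pads chosen/rejected pairs and forwards the `chosen_probs` soft labels."""

    tokenizer: PreTrainedTokenizerBase
    max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = {}
        for side in ("chosen", "rejected"):
            # Pad each side to the longest sequence in the batch.
            padded = self.tokenizer.pad(
                [
                    {
                        "input_ids": f[f"input_ids_{side}"],
                        "attention_mask": f[f"attention_mask_{side}"],
                    }
                    for f in features
                ],
                padding=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            batch[f"input_ids_{side}"] = padded["input_ids"]
            batch[f"attention_mask_{side}"] = padded["attention_mask"]
        # Shape (batch, 1) so the soft labels broadcast against the
        # (batch, 1) reward logits in compute_loss.
        batch["chosen_probs"] = torch.tensor(
            [f["chosen_probs"] for f in features], dtype=torch.float
        ).unsqueeze(1)
        batch["return_loss"] = True
        return batch

With a custom collator like this, `remove_unused_columns=False` also has to be set in the `RewardConfig`, since the trainer only forces it off when it builds the default collator itself.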
The second file is the training script that builds the `chosen_probs` soft labels and runs the trainer:
# Reward modeling on preference data.
from collections import defaultdict
import logging
import math
import os
import random
import sys

from alignment import (
    DataArguments,
    H4ArgumentParser,
    ModelArguments,
    RewardArguments,
    get_datasets,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
)
import torch
import transformers
from transformers import AutoModelForSequenceClassification, set_seed, TrainerCallback
from trl import RewardConfig, RewardTrainer

logger = logging.getLogger(__name__)


def main():
    parser = H4ArgumentParser((ModelArguments, DataArguments, RewardArguments))
    model_args, data_args, training_args = parser.parse()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    os.environ["WANDB_PROJECT"] = training_args.run_name

    #######
    # Setup
    #######
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed for reproducibility
    set_seed(training_args.seed)

    num_gpus = torch.cuda.device_count()
    per_device_batch = training_args.per_device_train_batch_size
    grad_accum_steps = training_args.gradient_accumulation_steps
    batch_size = num_gpus * per_device_batch * grad_accum_steps
    run_name = "-".join([
        f"b{batch_size}",
        f"lr{training_args.learning_rate}",
        f"s{max(0, training_args.max_steps)}",
        f"e{training_args.num_train_epochs}",
        f"btb{training_args.bt_beta or 'inf'}",
    ])
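    # Illustrative run name, assuming 8 GPUs, per-device batch 8, grad accum 2,
    # lr 1e-5, max_steps -1, 1 epoch, bt_beta 0.5: "b128-lr1e-05-s0-e1.0-btb0.5".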
| if "wandb" in training_args.report_to: | |
| training_args.tracker_kwargs = {"wandb": {"name": run_name}} | |
| training_args.run_name = run_name | |
| training_args.output_dir = os.path.join( | |
| training_args.output_dir, | |
| training_args.run_name, | |
| ) | |
| ############### | |
| # Load datasets | |
| ############### | |
| raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, token=training_args.hub_token) | |
| logger.info( | |
| f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" | |
| ) | |
| column_names = list(raw_datasets["train"].features) | |
| ##################################### | |
| # Load tokenizer and process datasets | |
| ##################################### | |
| data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn | |
| tokenizer = get_tokenizer(model_args, data_args) | |
| def preprocess_function(examples): | |
| new_examples = { | |
| "input_ids_chosen": [], | |
| "attention_mask_chosen": [], | |
| "input_ids_rejected": [], | |
| "attention_mask_rejected": [], | |
| "chosen_probs": [], | |
| } | |
| for (chosen_messages, rejected_messages, score_chosen, score_rejected) in zip( | |
| examples["chosen"], examples["rejected"], examples["score_chosen"], examples["score_rejected"]): | |
| # Confidence-aware Bradley-Terry model. | |
| # TODOs: 4.4, 2.2, 1.1, 0.5. | |
| chosen_p = 1.0 | |
| if training_args.bt_beta: | |
| score_diff = score_chosen - score_rejected | |
| chosen_p = 1.0 / (1.0 + math.exp(-training_args.bt_beta * score_diff)) | |
| new_examples["chosen_probs"].append(chosen_p) | |
            chosen_text = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            rejected_text = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
            # Ensure consistency with how the model is to be invoked, e.g., during PPO.
            chosen_tokenized = tokenizer(chosen_text)
            rejected_tokenized = tokenizer(rejected_text)
            new_examples["input_ids_chosen"].append(chosen_tokenized["input_ids"])
            new_examples["attention_mask_chosen"].append(chosen_tokenized["attention_mask"])
            new_examples["input_ids_rejected"].append(rejected_tokenized["input_ids"])
            new_examples["attention_mask_rejected"].append(rejected_tokenized["attention_mask"])
        return new_examples

    # Preprocess and filter examples that are longer than max_length.
    def length_filter(x):
        return (len(x["input_ids_chosen"]) <= training_args.max_length and
                len(x["input_ids_rejected"]) <= training_args.max_length)

    raw_datasets = raw_datasets.map(
        preprocess_function,
        # preprocess_function_raw,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
    )
    raw_datasets = raw_datasets.filter(length_filter)
    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["test"]

    ########################
    # Instantiate RM trainer
    ########################
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        num_labels=1,
        **model_kwargs,
    )
    logger.info("*** Model loaded! ***")
    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Train the model.
    trainer = RewardTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )
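    # Note: no data_collator is passed, so RewardTrainer falls back to the
    # default RewardDataCollatorWithPadding, which drops `chosen_probs`; a
    # collator along the lines of the sketch after the trainer file is assumed.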
    trainer.train(training_args.resume_from_checkpoint)

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process.
    if trainer.accelerator.is_main_process:
        kwargs = {
            "finetuned_from": model_args.model_name_or_path,
            "dataset": list(data_args.dataset_mixer.keys()),
            "dataset_tags": list(data_args.dataset_mixer.keys()),
        }
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    logger.info("*** Training complete ***")


if __name__ == "__main__":
    main()