Custom reward model with the PPOv2 trainer: a training script plus a PPOTrainerOurs subclass that scores rollouts with a reward model that uses its own tokenizer and chat template.
| """PPO v2 trainer.""" | |
| import logging | |
| import random | |
| from accelerate import PartialState | |
| from datasets import load_dataset | |
| import torch | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| HfArgumentParser, | |
| ) | |
| from trl import ModelConfig | |
| from trl.trainer.ppov2_trainer import PPOv2Config | |
| from trl.trainer.utils import get_quantization_config, first_true_indices | |
| from configs import DataConfig | |
| from trainer import PPOTrainerOurs | |
| from utils import get_datasets | |
| logger = logging.getLogger(__name__) | |
| def main(): | |
| parser = HfArgumentParser((PPOv2Config, ModelConfig, DataConfig)) | |
| config, model_config, data_config = parser.parse_args_into_dataclasses() | |
| torch_dtype = ( | |
| model_config.torch_dtype if model_config.torch_dtype in ["auto", None] else getattr(torch, model_config.torch_dtype) | |
| ) | |
| quantization_config = get_quantization_config(model_config) | |
| model_kwargs = dict( | |
| revision=model_config.model_revision, | |
| trust_remote_code=model_config.trust_remote_code, | |
| attn_implementation=model_config.attn_implementation, | |
| torch_dtype=torch_dtype, | |
| use_cache=False if config.gradient_checkpointing else True, | |
| quantization_config=quantization_config, | |
| ) | |
    ###################
    # Model & Tokenizer
    ###################
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    assert tokenizer.chat_template is not None

    # If the reward model shares the same backbone as the policy, it is known to be
    # better to warm-start the value network from the reward model. Otherwise, base
    # the value network on the policy network.
    # See https://arxiv.org/abs/2403.17031
    value_model = AutoModelForSequenceClassification.from_pretrained(
        config.sft_model_path, num_labels=1, **model_kwargs
    )
    ref_policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path, **model_kwargs)
    policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path, **model_kwargs)

    # Set up the reward model and its tokenizer.
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path, num_labels=1, **model_kwargs
    )
    reward_tokenizer = AutoTokenizer.from_pretrained(
        config.reward_model_path,
        trust_remote_code=model_config.trust_remote_code,
    )
    if getattr(reward_tokenizer, "pad_token", None) is None:
        reward_tokenizer.pad_token = reward_tokenizer.eos_token
        reward_tokenizer.pad_token_id = reward_tokenizer.eos_token_id
    reward_model.config.pad_token_id = reward_tokenizer.pad_token_id
    #########
    # Dataset
    #########
    system_role_supported = "gemma-2" not in config.sft_model_path

    def dataset_map_fn(example, tokenizer):
        messages = [{"role": "user", "content": example["prompt"]}]
        if system_role_supported:
            messages.insert(0, {"role": "system", "content": ""})
        query = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        r_query = reward_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        example_out = {
            "input_ids": tokenizer.encode(query),
            # Could also encode here, but keep the string to reduce padding in the trainer.
            "r_query_text": r_query,
        }
        return example_out

    # Process the dataset only on the main process for faster data processing.
    # See: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        column_names = [
            "prompt", "prompt_id", "chosen", "rejected", "messages",
            "score_chosen", "score_rejected", "source",
        ]
        data_config.dataset_mixer = {
            "allenai/ultrafeedback_binarized_cleaned": 1.0,
        }
        raw_datasets = get_datasets(
            data_config, splits=["train_gen", "test_gen"], columns_to_keep=column_names, shuffle=True
        )
        raw_datasets = raw_datasets.map(
            dataset_map_fn,
            fn_kwargs={"tokenizer": tokenizer},
            remove_columns=column_names,
            num_proc=data_config.preprocessing_num_workers,
        )
        raw_datasets.set_format(type="torch")

    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["test"]

    # Print out 3 random samples.
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]}")
    ##########
    # Training
    ##########
    trainer = PPOTrainerOurs(
        config=config,
        tokenizer=tokenizer,
        policy=policy,
        ref_policy=ref_policy,
        reward_model=reward_model,
        reward_tokenizer=reward_tokenizer,
        value_model=value_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()

    trainer.save_model(config.output_dir)
    if config.push_to_hub:
        trainer.push_to_hub()
    trainer.generate_completions()


if __name__ == "__main__":
    main()
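
Note that configs.py and utils.py (providing DataConfig and get_datasets) are local modules that are not included in the gist. Below is a minimal sketch of what they might look like, inferred only from how they are used above; the mixing logic and every name other than dataset_mixer and preprocessing_num_workers are assumptions, not the author's actual code.

"""Hypothetical sketch of the local configs.py / utils.py helpers (not part of the gist)."""
from dataclasses import dataclass, field
from typing import Dict, Optional

from datasets import DatasetDict, concatenate_datasets, load_dataset


@dataclass
class DataConfig:
    # Maps a dataset name to the fraction of each split to include in the mix.
    dataset_mixer: Dict[str, float] = field(default_factory=dict)
    preprocessing_num_workers: Optional[int] = None


def get_datasets(data_config, splits, columns_to_keep, shuffle=True, seed=42):
    """Load, subsample, and mix the datasets listed in `dataset_mixer`."""
    mixed = {}
    for split in splits:
        parts = []
        for name, frac in data_config.dataset_mixer.items():
            ds = load_dataset(name, split=split)
            ds = ds.remove_columns([c for c in ds.column_names if c not in columns_to_keep])
            parts.append(ds.select(range(int(frac * len(ds)))))
        # Key "train_gen"/"test_gen" splits as "train"/"test", matching how the
        # training script indexes the returned DatasetDict.
        key = "train" if split.startswith("train") else "test"
        out = concatenate_datasets(parts)
        mixed[key] = out.shuffle(seed=seed) if shuffle else out
    return DatasetDict(mixed)

The second file in the gist defines the custom data collator, the reward-model scoring helper, and the PPOTrainerOurs trainer itself: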
from collections import defaultdict
from dataclasses import dataclass
import gc
import math
import time
from typing import Any, Dict, List, Optional, Tuple, Union

from accelerate.utils import gather_object
from datasets import Dataset
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    DataCollatorWithPadding,
    GenerationConfig,
    PreTrainedTokenizer,
    TrainerCallback,
)
from trl.core import masked_mean, masked_whiten
from trl.models.utils import unwrap_model_for_generation
from trl.trainer.ppov2_trainer import INVALID_LOGPROB, PPOv2Config, PPOv2Trainer
from trl.trainer.utils import (
    batch_generation,
    first_true_indices,
    forward,
    get_reward,
    print_rich_table,
    truncate_response,
)

@dataclass
class RewardDataCollatorWithPadding:
    tokenizer: PreTrainedTokenizer
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        p_features = []
        r_queries = []
        for feature in features:
            p_features.append({"input_ids": feature["input_ids"]})
            r_queries.append(feature["r_query_text"])
        p_batch = self.tokenizer.pad(
            p_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch = {
            "input_ids": p_batch["input_ids"],
            "r_query_text": r_queries,
        }
        return batch
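
# Usage sketch (not part of the original gist; "gpt2" is just a stand-in tokenizer):
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   tok.pad_token = tok.eos_token
#   collator = RewardDataCollatorWithPadding(tokenizer=tok)
#   batch = collator([
#       {"input_ids": [101, 102, 103], "r_query_text": "<reward prompt 1>"},
#       {"input_ids": [104, 105], "r_query_text": "<reward prompt 2>"},
#   ])
# `batch["input_ids"]` is a padded (2, 3) LongTensor, while `batch["r_query_text"]`
# stays a list of raw strings so the reward prompt can be re-tokenized with the
# reward tokenizer only after a response has been generated.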

def get_score(
    model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Computes the reward logits and the rewards for a given model and query responses."""
    attention_mask = query_responses != pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
    lm_backbone = getattr(model, model.base_model_prefix)
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    output = lm_backbone(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
        # use_cache=False,  # otherwise mistral-based RM would error out
    )
    reward_logits = model.score(output.hidden_states[-1])
    # Index of the last non-pad token: the first pad position minus one, or the last
    # position when there is no padding. Computed on `query_responses` rather than
    # `input_ids`, whose pad tokens were replaced with 0 above.
    # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
    sequence_lengths = torch.eq(query_responses, pad_token_id).int().argmax(-1) - 1
    sequence_lengths = sequence_lengths % query_responses.shape[-1]
    return (
        reward_logits,
        reward_logits[
            torch.arange(reward_logits.size(0), device=reward_logits.device),
            sequence_lengths,
        ].squeeze(-1),
        sequence_lengths,
    )
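
# Worked example (numbers are hypothetical, not from the gist): with pad_token_id = 2
# and a right-padded batch
#   query_responses = [[5, 7, 9, 2, 2],
#                      [4, 6, 8, 3, 1]]
# torch.eq(..., 2).int().argmax(-1) gives [3, 0] (argmax of an all-False row is 0),
# subtracting 1 gives [2, -1], and the modulo wraps -1 to the last index 4, so the
# scores are read from reward_logits at positions [2, 4].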

class PPOTrainerOurs(PPOv2Trainer):
    def __init__(
        self,
        config: PPOv2Config,
        tokenizer: PreTrainedTokenizer,
        policy: nn.Module,
        ref_policy: nn.Module,
        reward_model: nn.Module,
        reward_tokenizer: PreTrainedTokenizer,
        train_dataset: Dataset,
        value_model: Optional[nn.Module] = None,
        data_collator: Optional[DataCollatorWithPadding] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        # less commonly used
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: Optional[List[TrainerCallback]] = None,
    ) -> None:
        data_collator = data_collator or RewardDataCollatorWithPadding(tokenizer=tokenizer)
        super().__init__(
            config=config,
            tokenizer=tokenizer,
            policy=policy,
            ref_policy=ref_policy,
            reward_model=reward_model,
            train_dataset=train_dataset,
            value_model=value_model,
            data_collator=data_collator,
            eval_dataset=eval_dataset,
            optimizers=optimizers,
            callbacks=callbacks,
        )
        self.reward_model = reward_model
        self.reward_tokenizer = reward_tokenizer

    def train(self):
        args = self.args
        accelerator = self.accelerator
        optimizer = self.optimizer
        model = self.model
        ref_policy = self.ref_policy
        r_model = self.reward_model
        tokenizer = self.tokenizer
        r_tokenizer = self.reward_tokenizer
        dataloader = self.dataloader
        device = accelerator.device

        def repeat_generator():
            while True:
                yield from dataloader

        iter_dataloader = iter(repeat_generator())
        generation_config = GenerationConfig(
            max_new_tokens=args.response_length,
            min_new_tokens=args.response_length,
            temperature=(args.temperature + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        accelerator.print("===training policy===")
        start_time = time.time()
        stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
        approxkl_stats = torch.zeros(stats_shape, device=device)
        pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
        pg_loss_stats = torch.zeros(stats_shape, device=device)
        vf_loss_stats = torch.zeros(stats_shape, device=device)
        vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
        entropy_stats = torch.zeros(stats_shape, device=device)
        ratio_stats = torch.zeros(stats_shape, device=device)
        model.train()

        # trainer state initialization
        self.state.global_step = 0
        self.state.episode = 0
        self.state.max_steps = args.num_total_batches * args.num_mini_batches
        self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        for update in range(1, args.num_total_batches + 1):
            self.state.episode += 1 * args.batch_size
            data = next(iter_dataloader)
            with torch.no_grad():
                queries = data["input_ids"].to(device)
                r_query_texts = data["r_query_text"]
                context_length = queries.shape[1]
                responses = []
                postprocessed_responses = []
                logprobs = []
                ref_logprobs = []
                scores = []
                sequence_lengths = []
                values = []
                with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                    query_responses, logitss = batch_generation(
                        unwrapped_model.policy,
                        queries,
                        args.local_rollout_forward_batch_size,
                        tokenizer.pad_token_id,
                        generation_config,
                    )

                for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                    query = queries[i : i + args.local_rollout_forward_batch_size]
                    r_query_text = r_query_texts[i : i + args.local_rollout_forward_batch_size]
                    query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                    response = query_response[:, context_length:]

                    logits = logitss[i : i + args.local_rollout_forward_batch_size]
                    all_logprob = F.log_softmax(logits, dim=-1)
                    logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
                    del logits, all_logprob
                    torch.cuda.empty_cache()

                    ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id)
                    ref_logits = ref_output.logits[:, context_length - 1 : -1]
                    ref_logits /= args.temperature + 1e-7
                    ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
                    ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
                    del ref_output, ref_logits, ref_all_logprob
                    torch.cuda.empty_cache()

                    # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                    postprocessed_response = response
                    if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        postprocessed_response = truncate_response(
                            args.stop_token_id, tokenizer.pad_token_id, response
                        )

                    # Response Processing 2. run reward model on the truncated responses
                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                    sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1
                    unwrapped_value_model = accelerator.unwrap_model(model).value_model
                    full_value, _, _ = get_reward(
                        unwrapped_value_model, query_response, tokenizer.pad_token_id, context_length
                    )
                    value = full_value[:, context_length - 1 : -1].squeeze(-1)
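
                    # The reward model may use a different tokenizer and chat template
                    # than the policy, so the reward prompt is carried as raw text,
                    # concatenated with the decoded response, and re-tokenized with the
                    # reward tokenizer before scoring.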
                    # TODO(kykim): Maybe don't skip_special_tokens but remove pad symbols afterwards.
                    postprocessed_response_text = tokenizer.batch_decode(postprocessed_response, skip_special_tokens=True)
                    r_query_response_text = [q + r for q, r in zip(r_query_text, postprocessed_response_text)]
                    r_query_response_input_ids = r_tokenizer(r_query_response_text, padding=True)["input_ids"]
                    r_query_response = torch.tensor(r_query_response_input_ids).to(device)
                    _, score, _ = get_score(r_model, r_query_response, r_tokenizer.pad_token_id)

                    responses.append(response)
                    postprocessed_responses.append(postprocessed_response)
                    logprobs.append(logprob)
                    ref_logprobs.append(ref_logprob)
                    sequence_lengths.append(sequence_length)
                    scores.append(score)
                    values.append(value)
                responses = torch.cat(responses, 0)
                postprocessed_responses = torch.cat(postprocessed_responses, 0)
                logprobs = torch.cat(logprobs, 0)
                ref_logprobs = torch.cat(ref_logprobs, 0)
                sequence_lengths = torch.cat(sequence_lengths, 0)
                scores = torch.cat(scores, 0)
                values = torch.cat(values, 0)
                del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
                torch.cuda.empty_cache()
                gc.collect()

                # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
                # responses not passing that filter will receive a low (fixed) score
                # only query humans on responses that pass that filter
                contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1)
                if args.non_eos_penalty:
                    scores = torch.where(contain_eos_token, scores, args.penalty_reward_value)
                # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

                # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
                response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
                padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
                logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
                ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
                sequence_lengths_p1 = sequence_lengths + 1
                padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
                values = torch.masked_fill(values, padding_mask_p1, 0)
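
                # Illustration (numbers are hypothetical, not from the gist): with
                # sequence_lengths = [2] for a 5-token response, padding_mask is
                # [F, F, F, T, T] (tokens 0-2 are valid) while padding_mask_p1 is
                # [F, F, F, F, T], which keeps one extra position so the value at
                # the step right after the last token is preserved.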

                # 4. compute rewards
                kl = logprobs - ref_logprobs
                non_score_reward = -args.kl_coef * kl
                rewards = non_score_reward.clone()
                actual_start = torch.arange(rewards.size(0), device=rewards.device)
                actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
                rewards[[actual_start, actual_end]] += scores
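                # In other words: every response token t receives only the KL penalty
                # r_t = -kl_coef * (logprob_t - ref_logprob_t), and the scalar
                # reward-model score is added once, at the last valid position of
                # each sequence.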

                # 5. whiten rewards
                if args.whiten_rewards:
                    rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                    rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

                # 6. compute advantages and returns
                lastgaelam = 0
                advantages_reversed = []
                gen_length = responses.shape[1]
                for t in reversed(range(gen_length)):
                    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                    lastgaelam = delta + args.gamma * args.lam * lastgaelam
                    advantages_reversed.append(lastgaelam)
                advantages = torch.stack(advantages_reversed[::-1], axis=1)
                returns = advantages + values
                advantages = masked_whiten(advantages, ~padding_mask)
                advantages = torch.masked_fill(advantages, padding_mask, 0)
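                # GAE recap: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and
                # A_t = delta_t + gamma * lam * A_{t+1}, computed backwards over the
                # response; returns_t = A_t + V(s_t) are the value-function targets.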
                torch.cuda.empty_cache()

            # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
            for ppo_epoch_idx in range(args.num_ppo_epochs):
                b_inds = np.random.permutation(args.local_batch_size)
                minibatch_idx = 0
                for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                    mini_batch_end = mini_batch_start + args.local_mini_batch_size
                    mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                    gradient_accumulation_idx = 0
                    for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                        with accelerator.accumulate(model):
                            micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                            micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                            mb_advantage = advantages[micro_batch_inds]
                            mb_responses = responses[micro_batch_inds]
                            mb_query_responses = query_responses[micro_batch_inds]
                            mb_logprobs = logprobs[micro_batch_inds]
                            mb_return = returns[micro_batch_inds]
                            mb_values = values[micro_batch_inds]
                            output, vpred_temp = forward(model, mb_query_responses, tokenizer.pad_token_id)
                            logits = output.logits[:, context_length - 1 : -1]
                            logits /= args.temperature + 1e-7
                            new_all_logprobs = F.log_softmax(logits, dim=-1)
                            new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
                            new_logprobs = torch.masked_fill(
                                new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                            )
                            vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                            vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                            vpredclipped = torch.clamp(
                                vpred,
                                mb_values - args.cliprange_value,
                                mb_values + args.cliprange_value,
                            )
                            vf_losses1 = torch.square(vpred - mb_return)
                            vf_losses2 = torch.square(vpredclipped - mb_return)
                            vf_loss_max = torch.max(vf_losses1, vf_losses2)
                            vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                            vf_clipfrac = masked_mean(
                                (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
                            )
                            logprobs_diff = new_logprobs - mb_logprobs
                            ratio = torch.exp(logprobs_diff)
                            pg_losses = -mb_advantage * ratio
                            pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                            pg_loss_max = torch.max(pg_losses, pg_losses2)
                            pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
                            loss = pg_loss + args.vf_coef * vf_loss
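                            # The policy term is the clipped PPO surrogate
                            # -min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)
                            # averaged over non-padding tokens, and the value term is
                            # the clipped squared error against the GAE returns,
                            # weighted by vf_coef.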
                            accelerator.backward(loss)
                            optimizer.step()
                            optimizer.zero_grad()
                            with torch.no_grad():
                                pg_clipfrac = masked_mean(
                                    (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                                )
                                prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                                entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                                approxkl = 0.5 * (logprobs_diff**2).mean()
                                approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                                pg_clipfrac_stats[
                                    ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
                                ] = pg_clipfrac
                                pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                                vf_clipfrac_stats[
                                    ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
                                ] = vf_clipfrac
                                entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                                ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                        gradient_accumulation_idx += 1
                    minibatch_idx += 1
                    # del everything and empty cache
                    # fmt: off
                    del (
                        output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped,
                        vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                        pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                        mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                    )
                    # fmt: on
                    torch.cuda.empty_cache()

            with torch.no_grad():
                mean_kl = kl.sum(1).mean()
                mean_entropy = (-logprobs).sum(1).mean()
                mean_non_score_reward = non_score_reward.sum(1).mean()
                rlhf_reward = mean_non_score_reward + scores.mean()
                eps = int(self.state.episode / (time.time() - start_time))
                metrics = {}
                metrics["eps"] = eps
                metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item()
                metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item()
                metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item()
                metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item()
                metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item()
                metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item()
                metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item()
                metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item()
                metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item()
                metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item()
                metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item()
                metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item()
                metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item()
                metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item()
                metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
                metrics["episode"] = self.state.episode
                self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
                self.state.global_step += 1
                self.log(metrics)

            self.lr_scheduler.step()
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)
            if self.control.should_save:
                self._save_checkpoint(model, trial=None, metrics=metrics)
                self.control = self.callback_handler.on_save(self.args, self.state, self.control)
            del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
            torch.cuda.empty_cache()
            gc.collect()

            if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
                self.generate_completions(sampling=True)
                torch.cuda.empty_cache()
            del (
                query_responses,
                responses,
                postprocessed_responses,
                logprobs,
                ref_logprobs,
                values,
                sequence_lengths,
                contain_eos_token,
                sequence_lengths_p1,
                response_idxs,
                padding_mask,
                padding_mask_p1,
                rewards,
                actual_start,
                actual_end,
                advantages,
                returns,
            )
            torch.cuda.empty_cache()

        # HF trainer specifics
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None, metrics=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def generate_completions(self, sampling: bool = False):
        args = self.args
        tokenizer = self.tokenizer
        r_model = self.reward_model
        r_tokenizer = self.reward_tokenizer
        device = self.accelerator.device
        generation_config = GenerationConfig(
            max_new_tokens=self.args.response_length,
            temperature=(0.01 + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        table = defaultdict(list)
        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
            for batch in self.eval_dataloader:
                query = batch["input_ids"]
                r_query_text = batch["r_query_text"]
                with torch.no_grad():
                    context_length = query.shape[1]
                    query_response, _ = batch_generation(
                        unwrapped_model.policy,
                        query,
                        query.shape[0],
                        tokenizer.pad_token_id,
                        generation_config,
                    )
                    response = query_response[:, context_length:]
                    postprocessed_response = response
                    if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        postprocessed_response = truncate_response(
                            args.stop_token_id, tokenizer.pad_token_id, response
                        )
                    table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True)))
                    table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response)))

                    postprocessed_response_text = tokenizer.batch_decode(postprocessed_response, skip_special_tokens=True)
                    r_query_response_text = [q + r for q, r in zip(r_query_text, postprocessed_response_text)]
                    r_query_response = torch.tensor(r_tokenizer(r_query_response_text, padding=True)["input_ids"]).to(device)
                    _, score, _ = get_score(r_model, r_query_response, r_tokenizer.pad_token_id)
                    table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())

                if sampling:
                    break
        df = pd.DataFrame(table)

        if self.accelerator.is_main_process:
            print_rich_table(df.iloc[0 : 0 + 5])
            if "wandb" in args.report_to:
                import wandb

                if wandb.run is not None:
                    wandb.log({"completions": wandb.Table(dataframe=df)})