@kykim0
Created September 1, 2024 07:40
Custom reward model for the PPOv2 trainer
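This gist contains two files: the training entrypoint directly below, and the custom trainer module it imports as `trainer` (reproduced after the entrypoint). The main change relative to TRL's stock `PPOv2Trainer` is that the reward model keeps its own tokenizer: each prompt is stored both as policy-tokenized `input_ids` and as chat-templated text (`r_query_text`), and sampled responses are decoded and re-encoded with the reward tokenizer before scoring.

A minimal launch sketch, assuming the entrypoint is saved as `ppov2_ours.py` (the filename, checkpoint paths, and hyperparameter values are placeholders; only flags the script actually consumes are shown):

accelerate launch ppov2_ours.py \
  --model_name_or_path <sft-checkpoint> \
  --sft_model_path <sft-checkpoint> \
  --reward_model_path <reward-checkpoint> \
  --output_dir outputs/ppov2 \
  --total_episodes 100000 \
  --response_length 256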
"""PPO v2 trainer."""
import logging
import random
from accelerate import PartialState
from datasets import load_dataset
import torch
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
)
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config
from trl.trainer.utils import get_quantization_config, first_true_indices
from configs import DataConfig
from trainer import PPOTrainerOurs
from utils import get_datasets
logger = logging.getLogger(__name__)
def main():
parser = HfArgumentParser((PPOv2Config, ModelConfig, DataConfig))
config, model_config, data_config = parser.parse_args_into_dataclasses()
torch_dtype = (
model_config.torch_dtype if model_config.torch_dtype in ["auto", None] else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if config.gradient_checkpointing else True,
quantization_config=quantization_config,
)
###################
# Model & Tokenizer
###################
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
padding_side="left",
trust_remote_code=model_config.trust_remote_code,
)
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
assert tokenizer.chat_template is not None
# When the reward model shares the same backbone as the policy, it is known to be better
# to warm-start the value network from the reward model. Otherwise, the value network
# should be initialized from the policy network.
# See https://arxiv.org/abs/2403.17031
value_model = AutoModelForSequenceClassification.from_pretrained(config.sft_model_path, num_labels=1, **model_kwargs)
ref_policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path, **model_kwargs)
policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path, **model_kwargs)
# Set up the reward model and the lambda.
reward_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, num_labels=1, **model_kwargs
)
reward_tokenizer = AutoTokenizer.from_pretrained(
config.reward_model_path,
trust_remote_code=model_config.trust_remote_code,
)
if getattr(reward_tokenizer, "pad_token", None) is None:
reward_tokenizer.pad_token = reward_tokenizer.eos_token
reward_tokenizer.pad_token_id = reward_tokenizer.eos_token_id
reward_model.config.pad_token_id = reward_tokenizer.pad_token_id
#########
# Dataset
#########
# gemma-2 chat templates do not support a system role.
system_role_supported = "gemma-2" not in config.sft_model_path
def dataset_map_fn(example, tokenizer):
messages = [{"role": "user", "content": example["prompt"]}]
if system_role_supported: messages.insert(0, {"role": "system", "content": ""})
query = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
r_query = reward_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
example_out = {
"input_ids": tokenizer.encode(query),
# The reward-model query could also be encoded here, but keeping it as a string reduces padding in the trainer.
"r_query_text": r_query,
}
return example_out
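# For illustration (the exact strings depend on each model's chat template), a mapped
# example looks roughly like:
#   {"input_ids": <policy-tokenizer ids of the chat-templated prompt>,
#    "r_query_text": "<the same prompt rendered with the reward model's chat template>"}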
# Preprocess the dataset, computing only on the main process for faster data processing.
# See: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
column_names = ["prompt", "prompt_id", "chosen", "rejected", "messages", "score_chosen", "score_rejected", "source"]
data_config.dataset_mixer = {
"allenai/ultrafeedback_binarized_cleaned": 1.0,
}
raw_datasets = get_datasets(data_config, splits=["train_gen", "test_gen"], columns_to_keep=column_names, shuffle=True)
raw_datasets = raw_datasets.map(
dataset_map_fn,
fn_kwargs={"tokenizer": tokenizer},
remove_columns=column_names,
num_proc=data_config.preprocessing_num_workers,
)
raw_datasets.set_format(type="torch")
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
# Print out 3 random samples.
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]}")
##########
# Training
##########
trainer = PPOTrainerOurs(
config=config,
tokenizer=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
reward_tokenizer=reward_tokenizer,
value_model=value_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
trainer.save_model(config.output_dir)
if config.push_to_hub:
trainer.push_to_hub()
trainer.generate_completions()
if __name__ == "__main__":
main()
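The custom trainer module, imported above as `trainer`: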
from collections import defaultdict
from dataclasses import dataclass
import gc
import math
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from accelerate.utils import gather_object
from datasets import Dataset
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
DataCollatorWithPadding,
GenerationConfig,
PreTrainedTokenizer,
TrainerCallback,
)
from trl.core import masked_mean, masked_whiten
from trl.models.utils import unwrap_model_for_generation
from trl.trainer.ppov2_trainer import INVALID_LOGPROB, PPOv2Config, PPOv2Trainer
from trl.trainer.utils import (
batch_generation,
first_true_indices,
forward,
get_reward,
print_rich_table,
truncate_response,
)
@dataclass
class RewardDataCollatorWithPadding:
tokenizer: PreTrainedTokenizer
padding: Union[bool, str] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
p_features = []
r_queries = []
for feature in features:
p_features.append({"input_ids": feature["input_ids"]})
r_queries.append(feature["r_query_text"])
p_batch = self.tokenizer.pad(
p_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch = {
"input_ids": p_batch["input_ids"],
"r_query_text": r_queries,
}
return batch
def get_score(
model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the reward logits and the rewards for a given model and query responses.
"""
attention_mask = query_responses != pad_token_id
position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum
lm_backbone = getattr(model, model.base_model_prefix)
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
output = lm_backbone(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
output_hidden_states=True,
# use_cache=False, # otherwise mistral-based RM would error out
)
reward_logits = model.score(output.hidden_states[-1])
# Index of the last non-padding token (cf. GPT2ForSequenceClassification). The un-masked
# `query_responses` must be used here, since pad positions in `input_ids` were zeroed
# above; the modulo handles sequences that contain no padding at all.
sequence_lengths = torch.eq(query_responses, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % query_responses.shape[-1]
# https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
return (
reward_logits,
reward_logits[
torch.arange(reward_logits.size(0), device=reward_logits.device),
sequence_lengths,
].squeeze(-1),
sequence_lengths,
)
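# Illustrative usage of `get_score` (hypothetical variable names; mirrors how the trainer
# calls it below):
#   texts = [q + r for q, r in zip(r_query_texts, response_texts)]
#   ids = torch.tensor(reward_tokenizer(texts, padding=True)["input_ids"]).to(device)
#   _, scores, _ = get_score(reward_model, ids, reward_tokenizer.pad_token_id)
# `scores` is a 1-D tensor with one scalar reward per sequence, read at its last non-pad token.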
class PPOTrainerOurs(PPOv2Trainer):
def __init__(
self,
config: PPOv2Config,
tokenizer: PreTrainedTokenizer,
policy: nn.Module,
ref_policy: nn.Module,
reward_model: nn.Module,
reward_tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
value_model: Optional[nn.Module] = None,
data_collator: Optional[DataCollatorWithPadding] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
# less commonly used
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
callbacks: Optional[List[TrainerCallback]] = None,
) -> None:
data_collator = data_collator or RewardDataCollatorWithPadding(tokenizer=tokenizer)
super().__init__(
config=config,
tokenizer=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
train_dataset=train_dataset,
value_model=value_model,
data_collator=data_collator,
eval_dataset=eval_dataset,
optimizers=optimizers,
callbacks=callbacks,
)
self.reward_model = reward_model
self.reward_tokenizer = reward_tokenizer
def train(self):
args = self.args
accelerator = self.accelerator
optimizer = self.optimizer
model = self.model
ref_policy = self.ref_policy
r_model = self.reward_model
tokenizer = self.tokenizer
r_tokenizer = self.reward_tokenizer
dataloader = self.dataloader
device = accelerator.device
def repeat_generator():
while True:
yield from dataloader
iter_dataloader = iter(repeat_generator())
generation_config = GenerationConfig(
max_new_tokens=args.response_length,
min_new_tokens=args.response_length,
temperature=(args.temperature + 1e-7),
top_k=0.0,
top_p=1.0,
do_sample=True,
)
accelerator.print("===training policy===")
start_time = time.time()
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
approxkl_stats = torch.zeros(stats_shape, device=device)
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
pg_loss_stats = torch.zeros(stats_shape, device=device)
vf_loss_stats = torch.zeros(stats_shape, device=device)
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
entropy_stats = torch.zeros(stats_shape, device=device)
ratio_stats = torch.zeros(stats_shape, device=device)
model.train()
# trainer state initialization
self.state.global_step = 0
self.state.episode = 0
self.state.max_steps = args.num_total_batches * args.num_mini_batches
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
for update in range(1, args.num_total_batches + 1):
self.state.episode += 1 * args.batch_size
data = next(iter_dataloader)
with torch.no_grad():
queries = data["input_ids"].to(device)
r_query_texts = data["r_query_text"]
context_length = queries.shape[1]
responses = []
postprocessed_responses = []
logprobs = []
ref_logprobs = []
scores = []
sequence_lengths = []
values = []
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
query_responses, logitss = batch_generation(
unwrapped_model.policy,
queries,
args.local_rollout_forward_batch_size,
tokenizer.pad_token_id,
generation_config,
)
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
query = queries[i : i + args.local_rollout_forward_batch_size]
r_query_text = r_query_texts[i : i + args.local_rollout_forward_batch_size]
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
logits = logitss[i : i + args.local_rollout_forward_batch_size]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del logits, all_logprob
torch.cuda.empty_cache()
ref_output = forward(ref_policy, query_response, tokenizer.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
torch.cuda.empty_cache()
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, tokenizer.pad_token_id, response
)
# Response Processing 2. run reward model on the truncated responses
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1
unwrapped_value_model = accelerator.unwrap_model(model).value_model
full_value, _, _ = get_reward(
unwrapped_value_model, query_response, tokenizer.pad_token_id, context_length
)
value = full_value[:, context_length - 1 : -1].squeeze(-1)
# TODO(kykim): Consider not using skip_special_tokens and instead stripping pad tokens afterwards.
postprocessed_response_text = tokenizer.batch_decode(postprocessed_response, skip_special_tokens=True)
r_query_response_text = [q + r for q, r in zip(r_query_text, postprocessed_response_text)]
r_query_response_input_ids = r_tokenizer(r_query_response_text, padding=True)["input_ids"]
r_query_response = torch.tensor(r_query_response_input_ids).to(device)
_, score, _ = get_score(r_model, r_query_response, r_tokenizer.pad_token_id)
responses.append(response)
postprocessed_responses.append(postprocessed_response)
logprobs.append(logprob)
ref_logprobs.append(ref_logprob)
sequence_lengths.append(sequence_length)
scores.append(score)
values.append(value)
responses = torch.cat(responses, 0)
postprocessed_responses = torch.cat(postprocessed_responses, 0)
logprobs = torch.cat(logprobs, 0)
ref_logprobs = torch.cat(ref_logprobs, 0)
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
values = torch.cat(values, 0)
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
torch.cuda.empty_cache()
gc.collect()
# Response Processing 3. filter response. Ensure that the sample contains stop_token_id
# responses not passing that filter will receive a low (fixed) score
# only query humans on responses that pass that filter
contain_eos_token = torch.any(postprocessed_responses[:, context_length:] == tokenizer.eos_token_id, dim=-1)
if args.non_eos_penalty:
scores = torch.where(contain_eos_token, scores, args.penalty_reward_value)
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
sequence_lengths_p1 = sequence_lengths + 1
padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
values = torch.masked_fill(values, padding_mask_p1, 0)
# 4. compute rewards
kl = logprobs - ref_logprobs
non_score_reward = -args.kl_coef * kl
rewards = non_score_reward.clone()
actual_start = torch.arange(rewards.size(0), device=rewards.device)
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
rewards[[actual_start, actual_end]] += scores
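# Each row of `rewards` is now the dense per-token KL penalty (zero at padded positions),
# with the scalar reward-model score added once at the end of the response
# (`sequence_lengths_p1`, or `sequence_lengths` when the response fills the whole window).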
# 5. whiten rewards
if args.whiten_rewards:
rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
# 6. compute advantages and returns
lastgaelam = 0
advantages_reversed = []
gen_length = responses.shape[1]
for t in reversed(range(gen_length)):
nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
lastgaelam = delta + args.gamma * args.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values
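# The loop above is standard GAE (Schulman et al., https://arxiv.org/abs/1506.02438):
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#   A_t     = delta_t + gamma * lam * A_{t+1}
# and `returns = advantages + values` gives the regression targets for the value head.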
advantages = masked_whiten(advantages, ~padding_mask)
advantages = torch.masked_fill(advantages, padding_mask, 0)
torch.cuda.empty_cache()
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
for ppo_epoch_idx in range(args.num_ppo_epochs):
b_inds = np.random.permutation(args.local_batch_size)
minibatch_idx = 0
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
mini_batch_end = mini_batch_start + args.local_mini_batch_size
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
gradient_accumulation_idx = 0
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
with accelerator.accumulate(model):
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
mb_advantage = advantages[micro_batch_inds]
mb_responses = responses[micro_batch_inds]
mb_query_responses = query_responses[micro_batch_inds]
mb_logprobs = logprobs[micro_batch_inds]
mb_return = returns[micro_batch_inds]
mb_values = values[micro_batch_inds]
output, vpred_temp = forward(model, mb_query_responses, tokenizer.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)
vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
vpredclipped = torch.clamp(
vpred,
mb_values - args.cliprange_value,
mb_values + args.cliprange_value,
)
vf_losses1 = torch.square(vpred - mb_return)
vf_losses2 = torch.square(vpredclipped - mb_return)
vf_loss_max = torch.max(vf_losses1, vf_losses2)
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
vf_clipfrac = masked_mean(
(vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
)
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
loss = pg_loss + args.vf_coef * vf_loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
with torch.no_grad():
pg_clipfrac = masked_mean(
(pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
)
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
approxkl = 0.5 * (logprobs_diff**2).mean()
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
pg_clipfrac_stats[
ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
] = pg_clipfrac
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
vf_clipfrac_stats[
ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx
] = vf_clipfrac
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
gradient_accumulation_idx += 1
minibatch_idx += 1
# del everything and empty cache
# fmt: off
del (
output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped,
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on
torch.cuda.empty_cache()
with torch.no_grad():
mean_kl = kl.sum(1).mean()
mean_entropy = (-logprobs).sum(1).mean()
mean_non_score_reward = non_score_reward.sum(1).mean()
rlhf_reward = mean_non_score_reward + scores.mean()
eps = int(self.state.episode / (time.time() - start_time))
metrics = {}
metrics["eps"] = eps
metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item()
metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item()
metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item()
metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item()
metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item()
metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item()
metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item()
metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item()
metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item()
metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item()
metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item()
metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item()
metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item()
metrics["val/num_eos_tokens"] = (responses == tokenizer.eos_token_id).sum().item()
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
metrics["episode"] = self.state.episode
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
self.state.global_step += 1
self.log(metrics)
self.lr_scheduler.step()
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
torch.cuda.empty_cache()
gc.collect()
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
self.generate_completions(sampling=True)
torch.cuda.empty_cache()
del (
query_responses,
responses,
postprocessed_responses,
logprobs,
ref_logprobs,
values,
sequence_lengths,
contain_eos_token,
sequence_lengths_p1,
response_idxs,
padding_mask,
padding_mask_p1,
rewards,
actual_start,
actual_end,
advantages,
returns,
)
torch.cuda.empty_cache()
# HF trainer specifics
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None, metrics=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def generate_completions(self, sampling: bool = False):
args = self.args
tokenizer = self.tokenizer
r_model = self.reward_model
r_tokenizer = self.reward_tokenizer
device = self.accelerator.device
generation_config = GenerationConfig(
max_new_tokens=self.args.response_length,
temperature=(0.01 + 1e-7),
top_k=0.0,
top_p=1.0,
do_sample=True,
)
table = defaultdict(list)
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
for batch in self.eval_dataloader:
query = batch["input_ids"]
r_query_text = batch["r_query_text"]
with torch.no_grad():
context_length = query.shape[1]
query_response, _ = batch_generation(
unwrapped_model.policy,
query,
query.shape[0],
tokenizer.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, tokenizer.pad_token_id, response
)
table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True)))
table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response)))
postprocessed_response_text = tokenizer.batch_decode(postprocessed_response, skip_special_tokens=True)
r_query_response_text = [q + r for q, r in zip(r_query_text, postprocessed_response_text)]
r_query_response = torch.tensor(r_tokenizer(r_query_response_text, padding=True)["input_ids"]).to(device)
_, score, _ = get_score(r_model, r_query_response, r_tokenizer.pad_token_id)
table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())
if sampling:
break
df = pd.DataFrame(table)
if self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})