RLAIF Steering Tokens
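
The five files below implement a simple RLAIF loop for training "steering tokens": create_pairs.py uses vLLM to sample pairs of candidate steering prefixes (text wrapped in <|steer_start|>/<|steer_end|>) for conversations from FineTome-100k; rank_pairs_kto.py labels each candidate by whether a ranking model assigns the original assistant reply lower perplexity when conditioned on that candidate; train.py runs TRL's KTOTrainer with a LoRA adapter on those labels; merge.py folds the adapter back into the base model; and the shell script chains the steps into repeated rounds, feeding each merged model into the next round's generation.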
# create_pairs.py
import argparse
import copy
import json
import random

from tqdm import tqdm
from datasets import load_dataset
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer


def create_pairs(
    dataset_name: str,
    model_name: str,
    tokenizer_name: str,
    pairs_per_example: int,
    total_pairs: int,
    min_length: int,
    max_length: int,
    seed: int,
    offset: int,
    output_filepath: str,
):
    random.seed(seed)
    llm = LLM(model=model_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    dataset = (
        load_dataset(dataset_name, streaming=True)["train"]
        .shuffle(seed=seed)
        .skip(offset)
    )
    dataset_iter = iter(dataset)
    progress = tqdm(total=total_pairs)
    progress.display(f"Processing dataset {dataset_name}...")
    pairs = []
    while len(pairs) < total_pairs:
        example = next(dataset_iter)
        messages = example.get("conversations")
        source = example.get("source")
        if not source or not messages:
            raise ValueError(f"Missing source or conversations: {example}")
        # Map ShareGPT-style roles ("human"/"gpt") onto chat-template roles
        fmt_msgs = [
            {
                "role": {
                    "system": "system",
                    "gpt": "assistant",
                    "human": "user",
                }[msg["from"]],
                "content": msg["value"],
            }
            for msg in messages
        ]
        history, last_message = (fmt_msgs[:-1], fmt_msgs[-1])
        if last_message["role"] != "assistant":
            raise ValueError(f"Last message is not from assistant: {example}")
        # Prompt the model so generation begins inside a steering block
        prefix = (
            tokenizer.apply_chat_template(
                history, add_generation_prompt=True, tokenize=False
            )
            + "<|steer_start|>"
        )
        gen_token_count = random.randrange(min_length, max_length)
        # TODO: add system message requesting `gen_token_count` of thinking tokens
        sampling_params = SamplingParams(
            temperature=1.0,
            min_tokens=gen_token_count,
            max_tokens=gen_token_count,
            n=pairs_per_example * 2,
        )
        output = llm.generate([prefix], sampling_params, use_tqdm=False)[0]
        ex = {
            "prompt": history,
            "steer1": None,
            "steer2": None,
            "completion": [last_message],
        }
        # Group the sampled candidates into (steer1, steer2) pairs
        for op in output.outputs:
            txt = (
                op.text.strip()
                .replace("<|steer_start|>", "")
                .replace("<|steer_end|>", "")
                .strip()
            )
            if ex["steer1"] is None:
                ex["steer1"] = txt
            elif ex["steer2"] is None:
                ex["steer2"] = txt
                pairs.append(copy.deepcopy(ex))
                progress.update(len(pairs) - progress.n)
                ex.update({"steer1": None, "steer2": None})
            else:
                raise ValueError("Unexpected output")
    with open(output_filepath, "w") as f:
        for pair in pairs:
            f.write(json.dumps(pair) + "\n")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="mlabonne/FineTome-100k")
    parser.add_argument(
        "--model_name", type=str, default="meta-llama/Meta-Llama-3.1-8B"
    )
    parser.add_argument(
        "--tokenizer_name", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
    )
    parser.add_argument("--pairs_per_example", type=int, default=10)
    parser.add_argument("--total_pairs", type=int, default=1000)
    parser.add_argument("--min_length", type=int, default=32)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--offset", type=int, default=0)
    parser.add_argument("--output_filepath", type=str, default="pairs_01.jsonl")
    args = parser.parse_args()
    create_pairs(
        dataset_name=args.dataset_name,
        model_name=args.model_name,
        tokenizer_name=args.tokenizer_name,
        pairs_per_example=args.pairs_per_example,
        total_pairs=args.total_pairs,
        min_length=args.min_length,
        max_length=args.max_length,
        seed=args.seed,
        offset=args.offset,
        output_filepath=args.output_filepath,
    )


if __name__ == "__main__":
    main()
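
For orientation only (this snippet is not one of the gist's files): each line create_pairs.py writes to its output JSONL has the shape sketched below. The field values here are invented; steer1 and steer2 are two sampled steering candidates for the same prompt, and completion holds the dataset's original assistant reply.

# Illustrative only: the shape of one record in pairs_01.jsonl (values invented).
import json

record = {
    "prompt": [{"role": "user", "content": "What is the capital of France?"}],
    "steer1": "The user wants a quick factual answer; respond in one short sentence.",
    "steer2": "Think about European geography before answering; keep the reply brief.",
    "completion": [{"role": "assistant", "content": "The capital of France is Paris."}],
}
print(json.dumps(record))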

# merge.py
import argparse
import shutil

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModelForCausalLM


def merge(
    base_model_name: str,
    chat_template_tokenizer_name: str,
    checkpoint_path: str,
    output_path: str,
):
    shutil.rmtree(output_path, ignore_errors=True)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    tokenizer.chat_template = AutoTokenizer.from_pretrained(
        chat_template_tokenizer_name
    ).chat_template
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name, device_map="auto"
    )
    base_model.resize_token_embeddings(len(tokenizer))
    model = PeftModelForCausalLM.from_pretrained(
        base_model, checkpoint_path, device_map="auto"
    )
    merged = model.merge_and_unload()
    merged.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)


def main():
    parser = argparse.ArgumentParser()
    # Add an argument for the base model name
    parser.add_argument(
        "--base_model_name",
        type=str,
        default="meta-llama/Meta-Llama-3.1-8B",
        help="The name of the base model",
    )
    # Add an argument for the chat template tokenizer name
    parser.add_argument(
        "--chat_template_tokenizer_name",
        type=str,
        default="NousResearch/Hermes-3-Llama-3.1-8B",
        help="The name of the chat template tokenizer",
    )
    # Add an argument for the checkpoint path
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="steer_round1",
        help="The path to the checkpoint",
    )
    # Add an argument for the output path
    parser.add_argument(
        "--output_path",
        type=str,
        default="steer_round1/merged",
        help="The path to save the merged model",
    )
    args = parser.parse_args()
    # Call the merge function with the provided arguments
    merge(
        args.base_model_name,
        args.chat_template_tokenizer_name,
        args.checkpoint_path,
        args.output_path,
    )


if __name__ == "__main__":
    main()

# rank_pairs_kto.py (referenced by the driver script below)
import argparse
import json
import random

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


def compute_perplexity(prompt, prefix, completion, tokenizer, model):
    """
    Compute the perplexity of the completion given the prompt and prefix,
    i.e. exp of the mean negative log-likelihood over the completion tokens.
    """
    # Concatenate the prompt, prefix, and completion
    full_input = prompt + prefix + completion
    # Tokenize the full input and get input IDs
    input_ids = tokenizer(full_input, return_tensors="pt").input_ids
    # Determine the length of the prompt + prefix to isolate the completion tokens
    prompt_prefix_ids = tokenizer(prompt + prefix, return_tensors="pt").input_ids
    prefix_length = prompt_prefix_ids.shape[1]
    # Get the model's logits (predictions before softmax)
    with torch.no_grad():
        outputs = model(input_ids=input_ids.to(device=model.device))
        logits = outputs.logits
    # Shift logits and labels to align for loss computation
    shift_logits = logits[:, :-1, :].squeeze(0)
    shift_labels = input_ids[:, 1:].squeeze(0)
    # Calculate the start index for the completion tokens in the shifted labels
    completion_start = prefix_length - 1  # Adjust for shifted labels
    # Extract logits and labels for the completion tokens
    completion_logits = shift_logits[completion_start:, :]
    completion_labels = shift_labels[completion_start:]
    # Compute log probabilities for the completion tokens
    log_probs = F.log_softmax(completion_logits, dim=-1)
    completion_log_probs = log_probs[
        range(completion_labels.shape[0]), completion_labels
    ]
    # Calculate the negative log-likelihood and perplexity
    nll = -completion_log_probs.sum()
    perplexity = torch.exp(nll / completion_labels.shape[0])
    return perplexity.item()


def rank_pairs(
    input_filepath: str,
    output_filepath: str,
    model_name: str,
    tokenizer_name: str,
    seed: int,
):
    random.seed(seed)
    with open(input_filepath) as f:
        data = [json.loads(line.strip()) for line in f if line.strip()]
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    output = []
    for row in tqdm(data):
        prompt: list = row["prompt"]
        steer_1: str = row["steer1"]
        steer_2: str = row["steer2"]
        completion: list = row["completion"]
        steer_1_chunk = f"<|steer_start|>{steer_1}<|steer_end|>"
        steer_2_chunk = f"<|steer_start|>{steer_2}<|steer_end|>"
        steer_1_pre_msg = {"role": "assistant", "content": steer_1_chunk}
        steer_2_pre_msg = {"role": "assistant", "content": steer_2_chunk}
        steer_1_msg = {
            "role": "assistant",
            "content": steer_1_chunk + completion[0]["content"],
        }
        steer_2_msg = {
            "role": "assistant",
            "content": steer_2_chunk + completion[0]["content"],
        }
        prompt_tmpl = tokenizer.apply_chat_template(
            prompt, tokenize=False, add_generation_prompt=True
        )
        # Template the steering chunk alone to find where the original completion starts
        steer_1_pre_tmp = tokenizer.apply_chat_template(
            prompt + [steer_1_pre_msg], tokenize=False
        )
        if steer_1_pre_tmp.endswith("<|eot_id|>"):
            steer_1_pre_tmp = steer_1_pre_tmp[: -len("<|eot_id|>")]
        steer_1_pre_count = len(steer_1_pre_tmp)
        steer_1_tmpl = tokenizer.apply_chat_template(
            prompt + [steer_1_msg], tokenize=False
        )[steer_1_pre_count:]
        steer_2_pre_tmp = tokenizer.apply_chat_template(
            prompt + [steer_2_pre_msg], tokenize=False
        )
        if steer_2_pre_tmp.endswith("<|eot_id|>"):
            steer_2_pre_tmp = steer_2_pre_tmp[: -len("<|eot_id|>")]
        steer_2_pre_count = len(steer_2_pre_tmp)
        steer_2_tmpl = tokenizer.apply_chat_template(
            prompt + [steer_2_msg], tokenize=False
        )[steer_2_pre_count:]
        # Score each steering candidate by how well it predicts the original completion
        perplexity_1 = compute_perplexity(
            prompt_tmpl, steer_1_chunk, steer_1_tmpl, tokenizer, model
        )
        perplexity_2 = compute_perplexity(
            prompt_tmpl, steer_2_chunk, steer_2_tmpl, tokenizer, model
        )
        better_1 = perplexity_1 < perplexity_2
        output.append(
            {"prompt": prompt, "completion": [steer_1_pre_msg], "label": better_1}
        )
        output.append(
            {"prompt": prompt, "completion": [steer_2_pre_msg], "label": not better_1}
        )
    with open(output_filepath, "w") as f:
        for out in output:
            f.write(json.dumps(out) + "\n")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_filepath", type=str, default="pairs_01.jsonl")
    parser.add_argument("--output_filepath", type=str, default="kto.jsonl")
    parser.add_argument(
        "--model_name", type=str, default="meta-llama/Meta-Llama-3.1-8B"
    )
    parser.add_argument(
        "--tokenizer_name", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    rank_pairs(
        input_filepath=args.input_filepath,
        output_filepath=args.output_filepath,
        model_name=args.model_name,
        tokenizer_name=args.tokenizer_name,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
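
Again purely for illustration (not part of the gist): for each pair, rank_pairs_kto.py emits two KTO records. The steering candidate under which the original reply had lower perplexity gets label True, and only the steering prefix itself, not the final reply, becomes the KTO completion. Values below are invented.

# Illustrative only: the two lines appended to kto.jsonl for one ranked pair.
import json

prompt = [{"role": "user", "content": "What is the capital of France?"}]
records = [
    {
        "prompt": prompt,
        "completion": [{"role": "assistant",
                        "content": "<|steer_start|>Answer directly in one sentence.<|steer_end|>"}],
        "label": True,   # lower perplexity on the original reply
    },
    {
        "prompt": prompt,
        "completion": [{"role": "assistant",
                        "content": "<|steer_start|>Compose a long essay on geography.<|steer_end|>"}],
        "label": False,  # higher perplexity on the original reply
    },
]
for rec in records:
    print(json.dumps(rec))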

#!/usr/bin/env bash
# Driver script: for each round, create steering pairs, rank them, KTO-train a LoRA,
# and merge it back into the model used for the next round.
#
# pip install -U pip
# pip install -U packaging wheel
# pip install -U torch
# pip install -U transformers accelerate trl wandb peft bitsandbytes liger-kernel flash_attn sentencepiece hf-transfer vllm
# pip install -U -e git+https://github.com/huggingface/trl.git#egg=trl

export WANDB_PROJECT=steering

# Common variables
DATASET_NAME="mlabonne/FineTome-100k"
TOKENIZER_NAME="meta-llama/Meta-Llama-3.1-8B-Instruct"
RANK_MODEL_NAME="meta-llama/Meta-Llama-3.1-8B-Instruct"
BASE_MODEL_NAME="meta-llama/Meta-Llama-3.1-8B"
PAIRS_PER_EXAMPLE=10
TOTAL_PAIRS=1000
MIN_LENGTH=32
MAX_LENGTH=64
SEED=42
LR=0.00001
MAX_LENGTH_MODEL=2176
PROMPT_LENGTH=2048
COMPLETION_LENGTH=128
BATCH_SIZE=1
GRAD_ACCUM_STEPS=24
EPOCHS=1
LORA_R=8
LORA_ALPHA=16
LORA_DROPOUT=0.05
TORCH_DTYPE="bfloat16"

# Function to run one iteration
run_iteration() {
    local N=$1
    local OFFSET=$2
    local MODEL_NAME=$3
    local PREV_ROUND=$((N - 1))
    local OUTPUT_DIR="steer_round${N}"
    local RUN_NAME="llama-3.1-8b-steer-round$(printf "%02d" ${N})"
    local PAIRS_FILE=$(printf "pairs_%02d.jsonl" ${N})
    local KTO_FILE="kto${N}.jsonl"
    echo "Running iteration $N"

    # Create pairs
    python create_pairs.py \
        --dataset_name "$DATASET_NAME" \
        --model_name "$MODEL_NAME" \
        --tokenizer_name "$TOKENIZER_NAME" \
        --pairs_per_example "$PAIRS_PER_EXAMPLE" \
        --total_pairs "$TOTAL_PAIRS" \
        --min_length "$MIN_LENGTH" \
        --max_length "$MAX_LENGTH" \
        --seed "$SEED" \
        --offset "$OFFSET" \
        --output_filepath "$PAIRS_FILE"

    # Rank pairs
    python rank_pairs_kto.py \
        --input_filepath "$PAIRS_FILE" \
        --output_filepath "$KTO_FILE" \
        --model_name "$RANK_MODEL_NAME" \
        --tokenizer_name "$RANK_MODEL_NAME" \
        --seed "$SEED"

    # Train model
    python train.py \
        --run_name="$RUN_NAME" \
        --model_name_or_path="$MODEL_NAME" \
        --dataset_name="$KTO_FILE" \
        --report_to="wandb" \
        --optim="adamw_bnb_8bit" \
        --lr_scheduler_type="cosine" \
        --learning_rate="$LR" \
        --max_length "$MAX_LENGTH_MODEL" \
        --max_prompt_length "$PROMPT_LENGTH" \
        --max_completion_length "$COMPLETION_LENGTH" \
        --remove_unused_columns=False \
        --attn_implementation="flash_attention_2" \
        --save_strategy="steps" \
        --save_steps 50 \
        --save_total_limit=10 \
        --per_device_train_batch_size="$BATCH_SIZE" \
        --per_device_eval_batch_size="$BATCH_SIZE" \
        --gradient_accumulation_steps="$GRAD_ACCUM_STEPS" \
        --logging_steps=1 \
        --num_train_epochs="$EPOCHS" \
        --gradient_checkpointing \
        --use_peft \
        --lora_r="$LORA_R" \
        --lora_alpha="$LORA_ALPHA" \
        --lora_dropout="$LORA_DROPOUT" \
        --torch_dtype="$TORCH_DTYPE" \
        --output_dir="$OUTPUT_DIR"

    # Merge checkpoints
    python merge.py \
        --base_model_name "$MODEL_NAME" \
        --chat_template_tokenizer_name "$TOKENIZER_NAME" \
        --checkpoint_path "$OUTPUT_DIR" \
        --output_path "$OUTPUT_DIR/merged"
}

# First iteration
N=1
OFFSET=0
MODEL_NAME="$BASE_MODEL_NAME"
run_iteration $N $OFFSET "$MODEL_NAME"

# Subsequent iterations: keep going, starting each round from the previous merge
N=2
while true; do
    OFFSET=$(((N - 1) * 100))
    PREV_ROUND=$((N - 1))
    MODEL_NAME="steer_round${PREV_ROUND}/merged"
    run_iteration $N $OFFSET "$MODEL_NAME"
    N=$((N + 1))
done

# train.py (invoked by the driver script above)
from dataclasses import dataclass

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import (
    KTOConfig,
    KTOTrainer,
    ModelConfig,
    get_peft_config,
    maybe_unpair_preference_dataset,
    setup_chat_format,
    get_kbit_device_map,
    get_quantization_config,
)


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the KTO training script.
    """

    dataset_name: str = "trl-lib/kto-mix-14k"


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
    script_args, kto_args, model_args = parser.parse_args_into_dataclasses()

    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if kto_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
    peft_config = get_peft_config(model_args)
    if peft_config is None:
        # Without PEFT, KTO needs an explicit frozen reference model
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
    else:
        ref_model = None

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # If we are aligning a base model, we use ChatML as the default template
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)
    # model.resize_token_embeddings(len(tokenizer))

    # Load the dataset (a local JSONL produced by rank_pairs_kto.py)
    # dataset = load_dataset(script_args.dataset_name)
    dataset = load_dataset("json", data_files=[script_args.dataset_name])["train"]
    dataset = dataset.train_test_split(test_size=0.025, seed=42)

    # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
    dataset = maybe_unpair_preference_dataset(
        dataset, num_proc=kto_args.dataset_num_proc
    )

    # Apply chat template
    def format_dataset(example):
        example["prompt"] = tokenizer.apply_chat_template(
            example["prompt"], tokenize=False
        )
        example["completion"] = tokenizer.apply_chat_template(
            example["completion"], tokenize=False
        )
        # if isinstance(example["completion"], str):
        #     example["prompt"] = tokenizer.apply_chat_template(
        #         example["prompt"], tokenize=False
        #     )
        #     example["completion"] = tokenizer.apply_chat_template(
        #         example["completion"], tokenize=False
        #     )
        # else:
        #     example["prompt"] = tokenizer.apply_chat_template(
        #         example["completion"][:-1], tokenize=False
        #     )
        #     example["completion"] = tokenizer.apply_chat_template(
        #         [example["completion"][-1]], tokenize=False
        #     )
        return example

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        formatted_dataset = dataset.map(
            format_dataset, num_proc=kto_args.dataset_num_proc
        )

    # Initialize the KTO trainer
    kto_trainer = KTOTrainer(
        model,
        ref_model,
        args=kto_args,
        train_dataset=formatted_dataset["train"],
        eval_dataset=formatted_dataset["test"],
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Train and save the final model
    kto_trainer.train()
    kto_trainer.save_model(kto_args.output_dir)
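
The gist stops at training, but given how create_pairs.py builds its prompts, inference with a merged round would presumably look like the sketch below: apply the chat template, append <|steer_start|>, and let the model write its own steering text before the visible reply. This is an assumption on my part, not code from the gist; the model path and sampling settings are placeholders.

# Hypothetical inference sketch (not part of the gist): let the merged model generate
# its own steering tokens, then the user-visible answer.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_path = "steer_round1/merged"  # produced by merge.py; path is an example
tokenizer = AutoTokenizer.from_pretrained(model_path)

messages = [{"role": "user", "content": "What is the capital of France?"}]
prompt = (
    tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    + "<|steer_start|>"
)

llm = LLM(model=model_path)
out = llm.generate([prompt], SamplingParams(temperature=0.7, max_tokens=512))[0]
text = out.outputs[0].text
# Everything up to <|steer_end|> is steering text; the remainder is the reply.
steer, _, reply = text.partition("<|steer_end|>")
print("steering:", steer.strip())
print("reply:", reply.strip())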