RLAIF Steering Tokens
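In rough terms, the files below implement one loop of RLAIF with "steering tokens": create_pairs.py uses vLLM to sample pairs of candidate steering prefixes (wrapped in <|steer_start|>...<|steer_end|>) for conversations from FineTome-100k; rank_pairs_kto.py labels each candidate by the perplexity its prefix induces on the ground-truth completion under an instruct model; train.py fine-tunes the base model on those labels with TRL's KTOTrainer and LoRA; merge.py folds the adapter back into the base model; and the shell script chains the steps into repeated rounds.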
# create_pairs.py
import argparse
import copy
import json
import random
from tqdm import tqdm
from datasets import load_dataset
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
def create_pairs(
    dataset_name: str,
    model_name: str,
    tokenizer_name: str,
    pairs_per_example: int,
    total_pairs: int,
    min_length: int,
    max_length: int,
    seed: int,
    offset: int,
    output_filepath: str,
):
    random.seed(seed)
    llm = LLM(model=model_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    dataset = (
        load_dataset(dataset_name, streaming=True)["train"]
        .shuffle(seed=seed)
        .skip(offset)
    )
    dataset_iter = iter(dataset)
    progress = tqdm(total=total_pairs)
    progress.display(f"Processing dataset {dataset_name}...")
    pairs = []
    while len(pairs) < total_pairs:
        example = next(dataset_iter)
        messages = example.get("conversations")
        source = example.get("source")
        if not source or not messages:
            raise ValueError(f"Missing source or conversations: {example}")
        # Map ShareGPT-style roles onto chat-template roles
        fmt_msgs = [
            {
                "role": {
                    "system": "system",
                    "gpt": "assistant",
                    "human": "user",
                }[msg["from"]],
                "content": msg["value"],
            }
            for msg in messages
        ]
        history, last_message = (fmt_msgs[:-1], fmt_msgs[-1])
        if last_message["role"] != "assistant":
            raise ValueError(f"Last message is not from assistant: {example}")
        prefix = (
            tokenizer.apply_chat_template(
                history, add_generation_prompt=True, tokenize=False
            )
            + "<|steer_start|>"
        )
        gen_token_count = random.randrange(min_length, max_length)
        # TODO: add system message requesting `gen_token_count` of thinking tokens
        sampling_params = SamplingParams(
            temperature=1.0,
            min_tokens=gen_token_count,
            max_tokens=gen_token_count,
            n=pairs_per_example * 2,
        )
        output = llm.generate([prefix], sampling_params, use_tqdm=False)[0]
        ex = {
            "prompt": history,
            "steer1": None,
            "steer2": None,
            "completion": [last_message],
        }
        # Bucket the sampled candidates into (steer1, steer2) pairs
        for op in output.outputs:
            txt = (
                op.text.strip()
                .replace("<|steer_start|>", "")
                .replace("<|steer_end|>", "")
                .strip()
            )
            if ex["steer1"] is None:
                ex["steer1"] = txt
            elif ex["steer2"] is None:
                ex["steer2"] = txt
                pairs.append(copy.deepcopy(ex))
                progress.update(len(pairs) - progress.n)
                ex.update({"steer1": None, "steer2": None})
            else:
                raise ValueError("Unexpected output")
    with open(output_filepath, "w") as f:
        for pair in pairs:
            f.write(json.dumps(pair) + "\n")
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="mlabonne/FineTome-100k")
    parser.add_argument(
        "--model_name", type=str, default="meta-llama/Meta-Llama-3.1-8B"
    )
    parser.add_argument(
        "--tokenizer_name", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
    )
    parser.add_argument("--pairs_per_example", type=int, default=10)
    parser.add_argument("--total_pairs", type=int, default=1000)
    parser.add_argument("--min_length", type=int, default=32)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--offset", type=int, default=0)
    parser.add_argument("--output_filepath", type=str, default="pairs_01.jsonl")
    args = parser.parse_args()
    create_pairs(
        dataset_name=args.dataset_name,
        model_name=args.model_name,
        tokenizer_name=args.tokenizer_name,
        pairs_per_example=args.pairs_per_example,
        total_pairs=args.total_pairs,
        min_length=args.min_length,
        max_length=args.max_length,
        seed=args.seed,
        offset=args.offset,
        output_filepath=args.output_filepath,
    )


if __name__ == "__main__":
    main()
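# For reference, each line of the pairs JSONL written above looks roughly like the
# record below (field values are illustrative, not taken from an actual run):
#
# {
#   "prompt": [{"role": "user", "content": "What is 2 + 2?"}],
#   "steer1": "The user is asking a basic arithmetic question...",
#   "steer2": "Let me think step by step about addition...",
#   "completion": [{"role": "assistant", "content": "2 + 2 = 4."}]
# }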
# merge.py
import argparse
import shutil
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModelForCausalLM
def merge(
    base_model_name: str,
    chat_template_tokenizer_name: str,
    checkpoint_path: str,
    output_path: str,
):
    shutil.rmtree(output_path, ignore_errors=True)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    tokenizer.chat_template = AutoTokenizer.from_pretrained(
        chat_template_tokenizer_name
    ).chat_template
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name, device_map="auto"
    )
    base_model.resize_token_embeddings(len(tokenizer))
    model = PeftModelForCausalLM.from_pretrained(
        base_model, checkpoint_path, device_map="auto"
    )
    merged = model.merge_and_unload()
    merged.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
def main():
    parser = argparse.ArgumentParser()
    # Add an argument for the base model name
    parser.add_argument(
        "--base_model_name",
        type=str,
        default="meta-llama/Meta-Llama-3.1-8B",
        help="The name of the base model",
    )
    # Add an argument for the chat template tokenizer name
    parser.add_argument(
        "--chat_template_tokenizer_name",
        type=str,
        default="NousResearch/Hermes-3-Llama-3.1-8B",
        help="The name of the chat template tokenizer",
    )
    # Add an argument for the checkpoint path
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="steer_round1",
        help="The path to the checkpoint",
    )
    # Add an argument for the output path
    parser.add_argument(
        "--output_path",
        type=str,
        default="steer_round1/merged",
        help="The path to save the merged model",
    )
    args = parser.parse_args()
    # Call the merge function with the provided arguments
    merge(
        args.base_model_name,
        args.chat_template_tokenizer_name,
        args.checkpoint_path,
        args.output_path,
    )


if __name__ == "__main__":
    main()
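# Note: nothing in this gist explicitly registers <|steer_start|> and <|steer_end|>
# as special tokens; merge.py simply resizes the base model's embeddings to match
# whatever tokenizer was saved with the checkpoint. If the steering markers are
# meant to be dedicated tokens rather than plain text, one way to add them would be
# the sketch below (illustrative, not taken from the gist):
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
#
# tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
# tok.add_special_tokens(
#     {"additional_special_tokens": ["<|steer_start|>", "<|steer_end|>"]}
# )
# mdl = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
# mdl.resize_token_embeddings(len(tok))  # allocate embedding rows for the new tokens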
# rank_pairs_kto.py
import argparse
import json
import random
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
def compute_perplexity(prompt, prefix, completion, tokenizer, model):
    """
    Compute the perplexity of the completion given the prompt and prefix.
    """
    # Concatenate the prompt, prefix, and completion
    full_input = prompt + prefix + completion
    # Tokenize the full input and get input IDs
    input_ids = tokenizer(full_input, return_tensors="pt").input_ids
    # Determine the length of the prompt + prefix to isolate the completion tokens
    prompt_prefix_ids = tokenizer(prompt + prefix, return_tensors="pt").input_ids
    prefix_length = prompt_prefix_ids.shape[1]
    # Get the model's logits (predictions before softmax)
    with torch.no_grad():
        outputs = model(input_ids=input_ids.to(device=model.device))
        logits = outputs.logits
    # Shift logits and labels to align for loss computation
    shift_logits = logits[:, :-1, :].squeeze(0)
    shift_labels = input_ids[:, 1:].squeeze(0)
    # Calculate the start index for the completion tokens in the shifted labels
    completion_start = prefix_length - 1  # Adjust for shifted labels
    # Extract logits and labels for the completion tokens
    completion_logits = shift_logits[completion_start:, :]
    completion_labels = shift_labels[completion_start:]
    # Compute log probabilities for the completion tokens
    log_probs = F.log_softmax(completion_logits, dim=-1)
    completion_log_probs = log_probs[
        range(completion_labels.shape[0]), completion_labels
    ]
    # Calculate the negative log-likelihood and perplexity
    nll = -completion_log_probs.sum()
    perplexity = torch.exp(nll / completion_labels.shape[0])
    return perplexity.item()
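# In other words, for completion tokens x_1..x_N the function returns
#   perplexity = exp( -(1/N) * sum_i log p(x_i | preceding tokens) ),
# so lower is better. For example (illustrative numbers), if a two-token completion
# gets probabilities 0.5 and 0.25, the NLL is 0.693 + 1.386 = 2.079 and the
# perplexity is exp(2.079 / 2) ≈ 2.83.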
def rank_pairs(
    input_filepath: str,
    output_filepath: str,
    model_name: str,
    tokenizer_name: str,
    seed: int,
):
    random.seed(seed)
    with open(input_filepath) as f:
        data = [json.loads(line.strip()) for line in f if line.strip()]
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    output = []
    for row in tqdm(data):
        prompt: list = row["prompt"]
        steer_1: str = row["steer1"]
        steer_2: str = row["steer2"]
        completion: list = row["completion"]
        steer_1_chunk = f"<|steer_start|>{steer_1}<|steer_end|>"
        steer_2_chunk = f"<|steer_start|>{steer_2}<|steer_end|>"
        steer_1_pre_msg = {"role": "assistant", "content": steer_1_chunk}
        steer_2_pre_msg = {"role": "assistant", "content": steer_2_chunk}
        steer_1_msg = {
            "role": "assistant",
            "content": steer_1_chunk + completion[0]["content"],
        }
        steer_2_msg = {
            "role": "assistant",
            "content": steer_2_chunk + completion[0]["content"],
        }
        prompt_tmpl = tokenizer.apply_chat_template(
            prompt, tokenize=False, add_generation_prompt=True
        )
        steer_1_pre_tmp = tokenizer.apply_chat_template(
            prompt + [steer_1_pre_msg], tokenize=False
        )
        if steer_1_pre_tmp.endswith("<|eot_id|>"):
            steer_1_pre_tmp = steer_1_pre_tmp[: -len("<|eot_id|>")]
        steer_1_pre_count = len(steer_1_pre_tmp)
        steer_1_tmpl = tokenizer.apply_chat_template(
            prompt + [steer_1_msg], tokenize=False
        )[steer_1_pre_count:]
        steer_2_pre_tmp = tokenizer.apply_chat_template(
            prompt + [steer_2_pre_msg], tokenize=False
        )
        if steer_2_pre_tmp.endswith("<|eot_id|>"):
            steer_2_pre_tmp = steer_2_pre_tmp[: -len("<|eot_id|>")]
        steer_2_pre_count = len(steer_2_pre_tmp)
        steer_2_tmpl = tokenizer.apply_chat_template(
            prompt + [steer_2_msg], tokenize=False
        )[steer_2_pre_count:]
        # Score each steering candidate by the perplexity it induces on the
        # ground-truth completion; the lower-perplexity candidate is labeled True
        perplexity_1 = compute_perplexity(
            prompt_tmpl, steer_1_chunk, steer_1_tmpl, tokenizer, model
        )
        perplexity_2 = compute_perplexity(
            prompt_tmpl, steer_2_chunk, steer_2_tmpl, tokenizer, model
        )
        better_1 = perplexity_1 < perplexity_2
        output.append(
            {"prompt": prompt, "completion": [steer_1_pre_msg], "label": better_1}
        )
        output.append(
            {"prompt": prompt, "completion": [steer_2_pre_msg], "label": not better_1}
        )
    with open(output_filepath, "w") as f:
        for out in output:
            f.write(json.dumps(out) + "\n")
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_filepath", type=str, default="pairs_01.jsonl")
    parser.add_argument("--output_filepath", type=str, default="kto.jsonl")
    parser.add_argument(
        "--model_name", type=str, default="meta-llama/Meta-Llama-3.1-8B"
    )
    parser.add_argument(
        "--tokenizer_name", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    rank_pairs(
        input_filepath=args.input_filepath,
        output_filepath=args.output_filepath,
        model_name=args.model_name,
        tokenizer_name=args.tokenizer_name,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
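# Each line of the resulting KTO JSONL has the unpaired (prompt, completion, label)
# shape that KTOTrainer expects; an illustrative record (values not from a real run):
#
# {
#   "prompt": [{"role": "user", "content": "What is 2 + 2?"}],
#   "completion": [{"role": "assistant",
#                   "content": "<|steer_start|>Simple arithmetic...<|steer_end|>"}],
#   "label": true
# }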
#!/usr/bin/env bash
# pip install -U pip
# pip install -U packaging wheel
# pip install -U torch
# pip install -U transformers accelerate trl wandb peft bitsandbytes liger-kernel flash_attn sentencepiece hf-transfer vllm
# pip install -U -e git+https://github.com/huggingface/trl.git#egg=trl
export WANDB_PROJECT=steering
# Common variables
DATASET_NAME="mlabonne/FineTome-100k"
TOKENIZER_NAME="meta-llama/Meta-Llama-3.1-8B-Instruct"
RANK_MODEL_NAME="meta-llama/Meta-Llama-3.1-8B-Instruct"
BASE_MODEL_NAME="meta-llama/Meta-Llama-3.1-8B"
PAIRS_PER_EXAMPLE=10
TOTAL_PAIRS=1000
MIN_LENGTH=32
MAX_LENGTH=64
SEED=42
LR=0.00001
MAX_LENGTH_MODEL=2176
PROMPT_LENGTH=2048
COMPLETION_LENGTH=128
BATCH_SIZE=1
GRAD_ACCUM_STEPS=24
EPOCHS=1
LORA_R=8
LORA_ALPHA=16
LORA_DROPOUT=0.05
TORCH_DTYPE="bfloat16"
# Function to run one iteration
run_iteration() {
    local N=$1
    local OFFSET=$2
    local MODEL_NAME=$3
    local PREV_ROUND=$((N - 1))
    local OUTPUT_DIR="steer_round${N}"
    local RUN_NAME="llama-3.1-8b-steer-round$(printf "%02d" ${N})"
    local PAIRS_FILE=$(printf "pairs_%02d.jsonl" ${N})
    local KTO_FILE="kto${N}.jsonl"

    echo "Running iteration $N"

    # Create pairs
    python create_pairs.py \
        --dataset_name "$DATASET_NAME" \
        --model_name "$MODEL_NAME" \
        --tokenizer_name "$TOKENIZER_NAME" \
        --pairs_per_example "$PAIRS_PER_EXAMPLE" \
        --total_pairs "$TOTAL_PAIRS" \
        --min_length "$MIN_LENGTH" \
        --max_length "$MAX_LENGTH" \
        --seed "$SEED" \
        --offset "$OFFSET" \
        --output_filepath "$PAIRS_FILE"

    # Rank pairs
    python rank_pairs_kto.py \
        --input_filepath "$PAIRS_FILE" \
        --output_filepath "$KTO_FILE" \
        --model_name "$RANK_MODEL_NAME" \
        --tokenizer_name "$RANK_MODEL_NAME" \
        --seed "$SEED"

    # Train model
    python train.py \
        --run_name="$RUN_NAME" \
        --model_name_or_path="$MODEL_NAME" \
        --dataset_name="$KTO_FILE" \
        --report_to="wandb" \
        --optim="adamw_bnb_8bit" \
        --lr_scheduler_type="cosine" \
        --learning_rate="$LR" \
        --max_length "$MAX_LENGTH_MODEL" \
        --max_prompt_length "$PROMPT_LENGTH" \
        --max_completion_length "$COMPLETION_LENGTH" \
        --remove_unused_columns=False \
        --attn_implementation="flash_attention_2" \
        --save_strategy="steps" \
        --save_steps 50 \
        --save_total_limit=10 \
        --per_device_train_batch_size="$BATCH_SIZE" \
        --per_device_eval_batch_size="$BATCH_SIZE" \
        --gradient_accumulation_steps="$GRAD_ACCUM_STEPS" \
        --logging_steps=1 \
        --num_train_epochs="$EPOCHS" \
        --gradient_checkpointing \
        --use_peft \
        --lora_r="$LORA_R" \
        --lora_alpha="$LORA_ALPHA" \
        --lora_dropout="$LORA_DROPOUT" \
        --torch_dtype="$TORCH_DTYPE" \
        --output_dir="$OUTPUT_DIR"

    # Merge checkpoints
    python merge.py \
        --base_model_name "$MODEL_NAME" \
        --chat_template_tokenizer_name "$TOKENIZER_NAME" \
        --checkpoint_path "$OUTPUT_DIR" \
        --output_path "$OUTPUT_DIR/merged"
}
# First iteration
N=1
OFFSET=0
MODEL_NAME="$BASE_MODEL_NAME"
run_iteration $N $OFFSET "$MODEL_NAME"
# Subsequent iterations: run until manually interrupted
N=2
while true; do
    OFFSET=$(((N - 1) * 100))
    PREV_ROUND=$((N - 1))
    MODEL_NAME="steer_round${PREV_ROUND}/merged"
    run_iteration $N $OFFSET "$MODEL_NAME"
    N=$((N + 1))
done
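# Note on the offsets above: each round collects TOTAL_PAIRS=1000 pairs at
# PAIRS_PER_EXAMPLE=10 pairs per dataset example, i.e. 100 examples per round,
# so round N skips the (N - 1) * 100 examples consumed by earlier rounds.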
# train.py
from dataclasses import dataclass
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
import torch
from trl import (
    KTOConfig,
    KTOTrainer,
    ModelConfig,
    get_peft_config,
    maybe_unpair_preference_dataset,
    setup_chat_format,
    get_kbit_device_map,
    get_quantization_config,
)
# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the KTO training script.
    """

    dataset_name: str = "trl-lib/kto-mix-14k"
if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
    script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if kto_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
    peft_config = get_peft_config(model_args)
    if peft_config is None:
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
    else:
        ref_model = None
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # If we are aligning a base model, we use ChatML as the default template
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)
        # model.resize_token_embeddings(len(tokenizer))

    # Load the dataset
    # dataset = load_dataset(script_args.dataset_name)
    dataset = load_dataset("json", data_files=[script_args.dataset_name])["train"]
    dataset = dataset.train_test_split(test_size=0.025, seed=42)

    # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
    dataset = maybe_unpair_preference_dataset(
        dataset, num_proc=kto_args.dataset_num_proc
    )

    # Apply chat template
    def format_dataset(example):
        # print(example)
        # print("-------")
        # print(example["prompt"])
        example["prompt"] = tokenizer.apply_chat_template(
            example["prompt"], tokenize=False
        )
        example["completion"] = tokenizer.apply_chat_template(
            example["completion"], tokenize=False
        )
        # if isinstance(example["completion"], str):
        #     example["prompt"] = tokenizer.apply_chat_template(
        #         example["prompt"], tokenize=False
        #     )
        #     example["completion"] = tokenizer.apply_chat_template(
        #         example["completion"], tokenize=False
        #     )
        # else:
        #     example["prompt"] = tokenizer.apply_chat_template(
        #         example["completion"][:-1], tokenize=False
        #     )
        #     example["completion"] = tokenizer.apply_chat_template(
        #         [example["completion"][-1]], tokenize=False
        #     )
        return example

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        formatted_dataset = dataset.map(
            format_dataset, num_proc=kto_args.dataset_num_proc
        )

    # Initialize the KTO trainer
    kto_trainer = KTOTrainer(
        model,
        ref_model,
        args=kto_args,
        train_dataset=formatted_dataset["train"],
        eval_dataset=formatted_dataset["test"],
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Train and save the model
    kto_trainer.train()
    kto_trainer.save_model(kto_args.output_dir)
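# After a round, the merged model in steer_roundN/merged can be prompted the same
# way create_pairs.py builds its prefixes: apply the chat template with
# add_generation_prompt=True, append "<|steer_start|>", and let the model emit its
# steering tokens before the visible answer. A minimal sketch (the path and prompt
# are illustrative, and plain transformers generation is used instead of vLLM):
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
#
# tok = AutoTokenizer.from_pretrained("steer_round1/merged")
# mdl = AutoModelForCausalLM.from_pretrained("steer_round1/merged", device_map="auto")
# prompt = tok.apply_chat_template(
#     [{"role": "user", "content": "What is 2 + 2?"}],
#     add_generation_prompt=True,
#     tokenize=False,
# ) + "<|steer_start|>"
# inputs = tok(prompt, return_tensors="pt").to(mdl.device)
# print(tok.decode(mdl.generate(**inputs, max_new_tokens=256)[0]))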