# Modified from this original script:
# https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/examples/scripts/orpo.py
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the ORPO training script with a command like the following (example
arguments shown). In general, the optimal configuration for ORPO will be
similar to that of DPO, without the need for a reference model:
python ORPO_PreferenceShareGPT.py \
--model_name_or_path "meta-llama/Meta-Llama-3-8B-Instruct" \
--dataset "PJMixers/Intel_orca_dpo_pairs-PreferenceShareGPT" \
--output_dir "./LLaMa-3-Instruct-Intel-Orca-ORPO-8B-QDoRA" \
--report_to wandb \
--run_name "LLaMa-3-Instruct-Intel-Orca-ORPO-8B-QDoRA" \
--push_to_hub True \
--hub_strategy "all_checkpoints" \
--hub_private_repo True \
--hub_model_id "xzuyn/LLaMa-3-Instruct-Intel-Orca-ORPO-8B-QDoRA" \
--num_train_epochs 1 \
--max_length 2048 \
--max_prompt_length 2048 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing True \
--learning_rate 0.00001 \
--lr_scheduler_type "cosine" \
--beta 0.1 \
--weight_decay 0.1 \
--max_grad_norm 1 \
--logging_steps 1 \
--warmup_steps 100 \
--save_strategy "steps" \
--save_steps 100 \
--optim paged_adamw_8bit \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--save_total_limit 2 \
--save_safetensors True \
--save_only_model True \
--seed 42 \
--manual_lora_rank 32 \
--manual_lora_alpha 32 \
--manual_use_dora True \
--manual_bos_token "<|begin_of_text|>" \
--manual_eos_token "<|end_of_text|>" \
--manual_pad_token "<|end_of_text|>" \
--manual_drop_long_samples True \
--manual_attn_implementation "flash_attention_2" \
--manual_prompt_format "llama_3_instruct"
"""
import multiprocessing
from dataclasses import dataclass, field

import torch
import transformers.modeling_utils
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import ModelConfig, ORPOConfig, ORPOTrainer


@dataclass
class ScriptArguments:
    dataset: str = field(
        default="PJMixers/Intel_orca_dpo_pairs-PreferenceShareGPT",
        metadata={"help": "The name of the dataset to use."},
    )
    manual_lora_rank: int = field(
        default=32,
        metadata={"help": "LoRA rank."},
    )
    manual_lora_alpha: int = field(
        default=32,
        metadata={"help": "LoRA alpha."},
    )
    manual_use_dora: bool = field(
        default=True,
        metadata={"help": "Enable DoRA."},
    )
    manual_bos_token: str = field(
        default="",
        metadata={"help": "BOS token to prepend; empty falls back to tokenizer.bos_token."},
    )
    manual_eos_token: str = field(
        default="",
        metadata={"help": "EOS token to append; empty falls back to tokenizer.eos_token."},
    )
    manual_pad_token: str = field(
        default="<|end_of_text|>",
        metadata={"help": "Pad token to use when the tokenizer has none."},
    )
    manual_drop_long_samples: bool = field(
        default=True,
        metadata={"help": "Choose to drop samples that are too long, or to truncate."},
    )
    manual_attn_implementation: str = field(
        default="flash_attention_2",
        metadata={"help": "Attention implementation passed to from_pretrained."},
    )
    manual_prompt_format: str = field(
        default="llama_3_instruct",
        metadata={"help": "One of: llama_3_instruct, chatml, fizzpaca."},
    )


# https://github.com/axolotl-ai-cloud/axolotl/pull/1528
class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
    """
    Saves VRAM by smartly offloading to RAM.
    Tiny hit to performance, since we mask the movement via non-blocking calls.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        # Offload the input activations to CPU; non_blocking=True overlaps the
        # copy with the no-grad forward pass below.
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            output = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dY):
        # Bring the activations back to the GPU, recompute the forward pass
        # with grad enabled, then backprop through it.
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
        hidden_states.requires_grad = True
        with torch.enable_grad():
            (output,) = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(output, dY)
        return (
            None,
            hidden_states.grad,
        ) + (
            None,
        ) * len(ctx.args)


# https://github.com/axolotl-ai-cloud/axolotl/pull/1528
def hf_grad_checkpoint_unsloth_wrapper(
    decoder_layer,
    *args,
    use_reentrant=None,
):
    return UnslothOffloadedGradientCheckpointer.apply(
        decoder_layer.__self__,
        *args,
    )
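
# Note: the wrapper above is installed in __main__ by replacing
# transformers.modeling_utils.checkpoint, so layers that Transformers would
# normally checkpoint are checkpointed with their inputs offloaded to CPU
# instead of kept on the GPU.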


def llama_3_instruct_prompt_format(
    sample,
    do_check=False,
):
    formatted_sample = {}
    prompt = tokenizer.bos_token if args.manual_bos_token == "" else args.manual_bos_token
    if prompt is None:
        prompt = ""
    for turn in sample["conversations"]:
        sharegpt_from, sharegpt_value = turn["from"].strip(), turn["value"].strip()
        if sharegpt_from == "system":
            role_name = "system"
        elif sharegpt_from == "human":
            role_name = "user"
        elif sharegpt_from == "human-chat":
            role_name = "user"
            sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
        elif sharegpt_from == "gpt":
            role_name = "assistant"
        elif sharegpt_from == "gpt-chat":
            role_name = "assistant"
            sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
        else:
            print(f"'from' contains an unhandled string: {sharegpt_from}")
            exit()
        prompt += (
            f"<|start_header_id|>{role_name}<|end_header_id|>\n\n"
            f"{sharegpt_value}<|eot_id|>"
        )
    formatted_sample["prompt"] = (
        f"{prompt}"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    formatted_sample["chosen"] = (
        f"{sample['chosen_gpt'].strip()}<|eot_id|>"
        f"{tokenizer.eos_token if args.manual_eos_token == '' else args.manual_eos_token}"
    )
    formatted_sample["rejected"] = (
        f"{sample['rejected_gpt'].strip()}<|eot_id|>"
        f"{tokenizer.eos_token if args.manual_eos_token == '' else args.manual_eos_token}"
    )
    # Used for filtering out samples which are too long before training
    if do_check is True:
        prompt_length = len(tokenizer.encode(formatted_sample["prompt"]))
        chosen_length = len(tokenizer.encode(formatted_sample["chosen"]))
        rejected_length = len(tokenizer.encode(formatted_sample["rejected"]))
        return prompt_length + max(chosen_length, rejected_length)
    return formatted_sample
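
# A minimal illustration (hypothetical sample; assumes the PreferenceShareGPT
# schema used above and the stock Llama-3 special tokens):
#
#   sample = {
#       "conversations": [{"from": "human", "value": "Hi"}],
#       "chosen_gpt": "Hello!",
#       "rejected_gpt": "Go away.",
#   }
#   llama_3_instruct_prompt_format(sample) ->
#   {
#       "prompt": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
#                 "Hi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
#       "chosen": "Hello!<|eot_id|><|end_of_text|>",
#       "rejected": "Go away.<|eot_id|><|end_of_text|>",
#   }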


def chatml_prompt_format(
    sample,
    do_check=False,
):
    formatted_sample = {}
    prompt = tokenizer.bos_token if args.manual_bos_token == "" else args.manual_bos_token
    if prompt is None:
        prompt = ""
    for turn in sample["conversations"]:
        sharegpt_from, sharegpt_value = turn["from"].strip(), turn["value"].strip()
        if sharegpt_from == "system":
            role_name = "system"
        elif sharegpt_from == "human":
            role_name = "user"
        elif sharegpt_from == "human-chat":
            role_name = "user"
            sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
        elif sharegpt_from == "gpt":
            role_name = "assistant"
        elif sharegpt_from == "gpt-chat":
            role_name = "assistant"
            sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
        else:
            print(f"'from' contains an unhandled string: {sharegpt_from}")
            exit()
        prompt += (
            f"<|im_start|>{role_name}\n"
            f"{sharegpt_value}<|im_end|>\n"
        )
    formatted_sample["prompt"] = (
        f"{prompt}"
        f"<|im_start|>assistant\n"
    )
    formatted_sample["chosen"] = (
        f"{sample['chosen_gpt'].strip()}<|im_end|>"
        f"{tokenizer.eos_token if args.manual_eos_token == '' else args.manual_eos_token}"
    )
    formatted_sample["rejected"] = (
        f"{sample['rejected_gpt'].strip()}<|im_end|>"
        f"{tokenizer.eos_token if args.manual_eos_token == '' else args.manual_eos_token}"
    )
    # Used for filtering out samples which are too long before training
    if do_check is True:
        prompt_length = len(tokenizer.encode(formatted_sample["prompt"]))
        chosen_length = len(tokenizer.encode(formatted_sample["chosen"]))
        rejected_length = len(tokenizer.encode(formatted_sample["rejected"]))
        return prompt_length + max(chosen_length, rejected_length)
    return formatted_sample
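
# Same structure as the Llama-3 formatter, rendered in ChatML: each turn
# becomes "<|im_start|>{role}\n{text}<|im_end|>\n", the prompt ends with
# "<|im_start|>assistant\n", and chosen/rejected end with <|im_end|> plus EOS.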


def fizzpaca_prompt_format(
    sample,
    do_check=False,
):
    formatted_sample = {}
    prompt = tokenizer.bos_token if args.manual_bos_token == "" else args.manual_bos_token
    if prompt is None:
        prompt = ""
    for turn in sample["conversations"]:
        sharegpt_from, sharegpt_value = turn["from"].strip(), turn["value"].strip()
        if sharegpt_from == "system":
            role_name = "### System:\n"
        elif sharegpt_from == "human":
            role_name = "### Instruction:\n"
        elif sharegpt_from == "human-chat":
            role_name = "### Instruction:\n"
            sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
        elif sharegpt_from == "gpt":
            role_name = "### Response:\n"
            sharegpt_value = (
                f"{sharegpt_value}"
                f"{tokenizer.eos_token if args.manual_eos_token == '' else args.manual_eos_token}"
            )
        elif sharegpt_from == "gpt-chat":
            role_name = "### Response:\n"
            sharegpt_value = (
                f"{turn['name'].strip()}: {sharegpt_value}"
                f"{tokenizer.eos_token if args.manual_eos_token == '' else args.manual_eos_token}"
            )
        else:
            print(f"'from' contains an unhandled string: {sharegpt_from}")
            exit()
        prompt += (
            f"{role_name}"
            f"{sharegpt_value}\n\n"
        )
    formatted_sample["prompt"] = (
        f"{prompt}"
        f"### Response:\n"
    )
    formatted_sample["chosen"] = (
        f"{sample['chosen_gpt'].strip()}"
        f"{tokenizer.eos_token if args.manual_eos_token == '' else args.manual_eos_token}"
    )
    formatted_sample["rejected"] = (
        f"{sample['rejected_gpt'].strip()}"
        f"{tokenizer.eos_token if args.manual_eos_token == '' else args.manual_eos_token}"
    )
    # Used for filtering out samples which are too long before training
    if do_check is True:
        prompt_length = len(tokenizer.encode(formatted_sample["prompt"]))
        chosen_length = len(tokenizer.encode(formatted_sample["chosen"]))
        rejected_length = len(tokenizer.encode(formatted_sample["rejected"]))
        return prompt_length + max(chosen_length, rejected_length)
    return formatted_sample
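
# "fizzpaca" is an Alpaca-style layout: "### System:", "### Instruction:", and
# "### Response:" headers with turns separated by blank lines. Unlike the two
# formats above, the EOS token is also appended after every assistant turn in
# the conversation history.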


if __name__ == "__main__":
    # Route Transformers' gradient checkpointing through the CPU-offloading
    # wrapper defined above.
    transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

    parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
    args, orpo_args, model_config = parser.parse_args_into_dataclasses()

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = args.manual_pad_token

    # Format (and optionally length-filter) the dataset
    prompt_formats = {
        "llama_3_instruct": llama_3_instruct_prompt_format,
        "chatml": chatml_prompt_format,
        "fizzpaca": fizzpaca_prompt_format,
    }
    if args.manual_prompt_format not in prompt_formats:
        print("Invalid 'manual_prompt_format'.")
        exit()
    prompt_format = prompt_formats[args.manual_prompt_format]
    ds = load_dataset(args.dataset)
    if args.manual_drop_long_samples:
        ds = ds.filter(
            lambda example: prompt_format(sample=example, do_check=True) <= orpo_args.max_length,
            load_from_cache_file=True,
        )
    ds = ds.map(
        prompt_format,
        num_proc=multiprocessing.cpu_count(),
        load_from_cache_file=True,
    )
    ds = ds.shuffle(seed=orpo_args.seed)

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation=args.manual_attn_implementation,
        quantization_config={  # Took these settings from what my Axolotl QLoRAs used
            "load_in_4bit": True,
            "load_in_8bit": False,
            "bnb_4bit_compute_dtype": "bfloat16",
            "bnb_4bit_quant_storage": "bfloat16",
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_use_double_quant": True,
            "llm_int8_enable_fp32_cpu_offload": False,
            "llm_int8_has_fp16_weight": False,
            "llm_int8_threshold": 6.0,
        },
    )
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Load PEFT
    peft_config = LoraConfig(
        r=args.manual_lora_rank,
        lora_alpha=args.manual_lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "gate_proj",
            "down_proj",
            "up_proj",
            "q_proj",
            "v_proj",
            "k_proj",
            "o_proj",
        ],
        use_dora=args.manual_use_dora,
    )
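
    # Note: load_in_4bit above plus use_dora here is what makes this a "QDoRA"
    # run (a DoRA adapter trained on an NF4-quantized base model).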

    # Start training
    trainer = ORPOTrainer(
        model,
        args=orpo_args,
        train_dataset=ds["train"],
        tokenizer=tokenizer,
        peft_config=peft_config,
    )

    # Train and save the model
    trainer.train()
    trainer.save_model(orpo_args.output_dir)
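
    # A minimal sketch of how the saved adapter could be loaded afterwards
    # (illustrative only; AutoPeftModelForCausalLM reattaches the adapter to
    # its base model):
    #   from peft import AutoPeftModelForCausalLM
    #   model = AutoPeftModelForCausalLM.from_pretrained(orpo_args.output_dir)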