Instantly share code, notes, and snippets.
Last active
July 29, 2024 12:30
-
Star
0
(0)
You must be signed in to star a gist -
Fork
0
(0)
You must be signed in to fork a gist
-
-
Save xzuyn/87097972ab2323ced81e4d7b41c47a45 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Modified from this original script:
# https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/examples/scripts/orpo.py
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the ORPO training script with the following command with some example arguments.
In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model:
python ORPO_PreferenceShareGPT.py \
    --model_name_or_path "meta-llama/Meta-Llama-3-8B-Instruct" \
    --dataset "PJMixers/Intel_orca_dpo_pairs-PreferenceShareGPT" \
    --output_dir "./LLaMa-3-Instruct-Intel-Orca-ORPO-8B-QDoRA" \
    --report_to wandb \
    --run_name "LLaMa-3-Instruct-Intel-Orca-ORPO-8B-QDoRA" \
    --push_to_hub True \
    --hub_strategy "all_checkpoints" \
    --hub_private_repo True \
    --hub_model_id "xzuyn/LLaMa-3-Instruct-Intel-Orca-ORPO-8B-QDoRA" \
    --num_train_epochs 1 \
    --max_length 2048 \
    --max_prompt_length 2048 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing True \
    --learning_rate 0.00001 \
    --lr_scheduler_type "cosine" \
    --beta 0.1 \
    --weight_decay 0.1 \
    --max_grad_norm 1 \
    --logging_steps 1 \
    --warmup_steps 100 \
    --save_strategy "steps" \
    --save_steps 100 \
    --optim paged_adamw_8bit \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns \
    --save_total_limit 2 \
    --save_safetensors True \
    --save_only_model True \
    --seed 42 \
    --manual_lora_rank 32 \
    --manual_lora_alpha 32 \
    --manual_use_dora True \
    --manual_bos_token "<|begin_of_text|>" \
    --manual_eos_token "<|end_of_text|>" \
    --manual_pad_token "<|end_of_text|>" \
    --manual_drop_long_samples True \
    --manual_attn_implementation "flash_attention_2" \
    --manual_prompt_format "llama_3_instruct"
"""
| import torch | |
| import multiprocessing | |
| from dataclasses import dataclass, field | |
| from datasets import load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser | |
| from peft import LoraConfig | |
| from trl import ModelConfig, ORPOConfig, ORPOTrainer | |
| from tqdm import tqdm | |
| import transformers.modeling_utils | |
@dataclass
class ScriptArguments:
    """Script-specific CLI arguments, parsed alongside ORPOConfig and ModelConfig."""

    # HF Hub id of a PreferenceShareGPT-style dataset: a "conversations" list of
    # {"from", "value", ...} turns plus "chosen_gpt"/"rejected_gpt" responses.
    dataset: str = field(
        default="PJMixers/Intel_orca_dpo_pairs-PreferenceShareGPT",
        metadata={"help": "The name of the dataset to use."},
    )
    manual_lora_rank: int = field(
        default=32,
        metadata={"help": "LoRA Rank."}
    )
    manual_lora_alpha: int = field(
        default=32,
        metadata={"help": "LoRA Alpha."}
    )
    manual_use_dora: bool = field(
        default=True,
        metadata={"help": "Enable DoRA."}
    )
    # Token overrides; empty string means "use the tokenizer's own token".
    manual_bos_token: str = field(
        default=""
    )
    manual_eos_token: str = field(
        default=""
    )
    # Only applied when the loaded tokenizer has no pad token of its own.
    manual_pad_token: str = field(
        default="<|end_of_text|>"
    )
    manual_drop_long_samples: bool = field(
        default=True,
        metadata={"help": "Choose to drop samples that are too long, or to truncate."}
    )
    # Forwarded to AutoModelForCausalLM.from_pretrained().
    manual_attn_implementation: str = field(
        default="flash_attention_2"
    )
    # One of: "llama_3_instruct", "chatml", "fizzpaca".
    manual_prompt_format: str = field(
        default="llama_3_instruct"
    )
# https://github.com/axolotl-ai-cloud/axolotl/pull/1528
class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
    """
    Saves VRAM by smartly offloading to RAM.
    Tiny hit to performance, since we mask the movement via non-blocking calls.
    """
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        # Start an async copy of the input activations to CPU, then run the
        # layer with autograd disabled; backward() recomputes the graph later.
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            output = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dY):
        # Bring the offloaded activations back to the GPU and detach so they
        # become a fresh leaf we can collect gradients on.
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
        hidden_states.requires_grad = True
        with torch.enable_grad():
            # NOTE(review): this unpacking assumes forward_function returns a
            # 1-tuple, while forward() above returned its result as-is —
            # confirm against the decoder layers this actually wraps.
            (output,) = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(output, dY)
        # Gradients only for hidden_states; None for forward_function and for
        # each extra positional arg.
        return (
            None,
            hidden_states.grad,
        ) + (
            None,
        ) * len(ctx.args)
# https://github.com/axolotl-ai-cloud/axolotl/pull/1528
def hf_grad_checkpoint_unsloth_wrapper(
    decoder_layer,
    *args,
    use_reentrant=None
):
    """
    Drop-in replacement for transformers' checkpoint() that routes through the
    CPU-offloading UnslothOffloadedGradientCheckpointer.

    decoder_layer is expected to be a bound method (e.g. a module's
    __call__/forward); `__self__` recovers the owning module, which is itself
    callable and is used as the checkpointed forward function. use_reentrant
    is accepted only for signature compatibility and ignored.
    """
    return UnslothOffloadedGradientCheckpointer.apply(
        decoder_layer.__self__,
        *args,
    )
def llama_3_instruct_prompt_format(
    sample,
    do_check=False
):
    """Format a PreferenceShareGPT sample with the Llama-3-Instruct template.

    Relies on the module-level ``tokenizer`` and parsed ``args`` that are set
    in ``__main__`` before the dataset is mapped.

    Args:
        sample: Mapping with a "conversations" list of {"from", "value", ...}
            turns plus "chosen_gpt" and "rejected_gpt" response strings.
        do_check: When True, return the token length of the prompt plus the
            longer of the two responses (used to filter overlong samples).

    Returns:
        dict with "prompt", "chosen", "rejected" strings, or an int token
        count when ``do_check`` is True.
    """
    # Explicit CLI overrides win; empty string falls back to the tokenizer.
    bos = tokenizer.bos_token if args.manual_bos_token == "" else args.manual_bos_token
    eos = tokenizer.eos_token if args.manual_eos_token == "" else args.manual_eos_token
    # tokenizer.bos_token can be None; start from an empty prompt in that case.
    prompt = "" if bos is None else bos

    for turn in sample["conversations"]:
        sharegpt_from = turn["from"].strip()
        sharegpt_value = turn["value"].strip()
        if sharegpt_from == "system":
            role_name = "system"
        elif sharegpt_from in ("human", "human-chat"):
            role_name = "user"
        elif sharegpt_from in ("gpt", "gpt-chat"):
            role_name = "assistant"
        else:
            print("'from' contains an unhandled string")
            exit()
        # Multi-speaker chat turns carry a speaker name to prefix.
        if sharegpt_from in ("human-chat", "gpt-chat"):
            sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
        prompt += (
            f"<|start_header_id|>{role_name}<|end_header_id|>\n\n"
            f"{sharegpt_value}<|eot_id|>"
        )

    formatted_sample = {
        "prompt": f"{prompt}<|start_header_id|>assistant<|end_header_id|>\n\n",
        "chosen": f"{sample['chosen_gpt'].strip()}<|eot_id|>{eos}",
        "rejected": f"{sample['rejected_gpt'].strip()}<|eot_id|>{eos}",
    }

    # Used for filtering out samples which are too long before training.
    if do_check is True:
        prompt_length = len(tokenizer.encode(formatted_sample["prompt"]))
        chosen_length = len(tokenizer.encode(formatted_sample["chosen"]))
        rejected_length = len(tokenizer.encode(formatted_sample["rejected"]))
        return prompt_length + max(chosen_length, rejected_length)
    return formatted_sample
def chatml_prompt_format(
    sample,
    do_check=False
):
    """Format a PreferenceShareGPT sample with the ChatML template.

    Relies on the module-level ``tokenizer`` and parsed ``args`` that are set
    in ``__main__`` before the dataset is mapped.

    Args:
        sample: Mapping with a "conversations" list of {"from", "value", ...}
            turns plus "chosen_gpt" and "rejected_gpt" response strings.
        do_check: When True, return the token length of the prompt plus the
            longer of the two responses (used to filter overlong samples).

    Returns:
        dict with "prompt", "chosen", "rejected" strings, or an int token
        count when ``do_check`` is True.
    """
    # Explicit CLI overrides win; empty string falls back to the tokenizer.
    bos = tokenizer.bos_token if args.manual_bos_token == "" else args.manual_bos_token
    eos = tokenizer.eos_token if args.manual_eos_token == "" else args.manual_eos_token
    # tokenizer.bos_token can be None; start from an empty prompt in that case.
    prompt = "" if bos is None else bos

    for turn in sample["conversations"]:
        sharegpt_from = turn["from"].strip()
        sharegpt_value = turn["value"].strip()
        if sharegpt_from == "system":
            role_name = "system"
        elif sharegpt_from in ("human", "human-chat"):
            role_name = "user"
        elif sharegpt_from in ("gpt", "gpt-chat"):
            role_name = "assistant"
        else:
            print("'from' contains an unhandled string")
            exit()
        # Multi-speaker chat turns carry a speaker name to prefix.
        if sharegpt_from in ("human-chat", "gpt-chat"):
            sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
        prompt += (
            f"<|im_start|>{role_name}\n"
            f"{sharegpt_value}<|im_end|>\n"
        )

    formatted_sample = {
        "prompt": f"{prompt}<|im_start|>assistant\n",
        "chosen": f"{sample['chosen_gpt'].strip()}<|im_end|>{eos}",
        "rejected": f"{sample['rejected_gpt'].strip()}<|im_end|>{eos}",
    }

    # Used for filtering out samples which are too long before training.
    if do_check is True:
        prompt_length = len(tokenizer.encode(formatted_sample["prompt"]))
        chosen_length = len(tokenizer.encode(formatted_sample["chosen"]))
        rejected_length = len(tokenizer.encode(formatted_sample["rejected"]))
        return prompt_length + max(chosen_length, rejected_length)
    return formatted_sample
def fizzpaca_prompt_format(
    sample,
    do_check=False
):
    """Format a PreferenceShareGPT sample with the "fizzpaca" (Alpaca-style) template.

    Unlike the other templates, assistant turns inside the conversation get the
    EOS token appended in-loop, and the final responses carry no end-of-turn
    marker other than EOS.

    Relies on the module-level ``tokenizer`` and parsed ``args`` that are set
    in ``__main__`` before the dataset is mapped.

    Args:
        sample: Mapping with a "conversations" list of {"from", "value", ...}
            turns plus "chosen_gpt" and "rejected_gpt" response strings.
        do_check: When True, return the token length of the prompt plus the
            longer of the two responses (used to filter overlong samples).

    Returns:
        dict with "prompt", "chosen", "rejected" strings, or an int token
        count when ``do_check`` is True.
    """
    # Explicit CLI overrides win; empty string falls back to the tokenizer.
    bos = tokenizer.bos_token if args.manual_bos_token == "" else args.manual_bos_token
    eos = tokenizer.eos_token if args.manual_eos_token == "" else args.manual_eos_token
    # tokenizer.bos_token can be None; start from an empty prompt in that case.
    prompt = "" if bos is None else bos

    for turn in sample["conversations"]:
        sharegpt_from = turn["from"].strip()
        sharegpt_value = turn["value"].strip()
        if sharegpt_from == "system":
            role_name = "### System:\n"
        elif sharegpt_from == "human":
            role_name = "### Instruction:\n"
        elif sharegpt_from == "human-chat":
            role_name = "### Instruction:\n"
            sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}"
        elif sharegpt_from == "gpt":
            role_name = "### Response:\n"
            sharegpt_value = f"{sharegpt_value}{eos}"
        elif sharegpt_from == "gpt-chat":
            role_name = "### Response:\n"
            sharegpt_value = f"{turn['name'].strip()}: {sharegpt_value}{eos}"
        else:
            print("'from' contains an unhandled string")
            exit()
        prompt += f"{role_name}{sharegpt_value}\n\n"

    formatted_sample = {
        "prompt": f"{prompt}### Response:\n",
        "chosen": f"{sample['chosen_gpt'].strip()}{eos}",
        "rejected": f"{sample['rejected_gpt'].strip()}{eos}",
    }

    # Used for filtering out samples which are too long before training.
    if do_check is True:
        prompt_length = len(tokenizer.encode(formatted_sample["prompt"]))
        chosen_length = len(tokenizer.encode(formatted_sample["chosen"]))
        rejected_length = len(tokenizer.encode(formatted_sample["rejected"]))
        return prompt_length + max(chosen_length, rejected_length)
    return formatted_sample
if __name__ == "__main__":
    # Route HF gradient checkpointing through the CPU-offloading checkpointer above.
    transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

    parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
    args, orpo_args, model_config = parser.parse_args_into_dataclasses()

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = args.manual_pad_token

    # Pick the formatting function once; the filter/map/shuffle pipeline is
    # identical for every prompt format (previously triplicated per branch).
    # Validating before load_dataset also fails fast on a bad format name.
    if args.manual_prompt_format == "llama_3_instruct":
        formatting_func = llama_3_instruct_prompt_format
    elif args.manual_prompt_format == "chatml":
        formatting_func = chatml_prompt_format
    elif args.manual_prompt_format == "fizzpaca":
        formatting_func = fizzpaca_prompt_format
    else:
        print("Invalid 'manual_prompt_format'.")
        exit()

    # Tokenize dataset
    ds = load_dataset(args.dataset)
    if args.manual_drop_long_samples:
        # Drop samples whose prompt + longest response would exceed max_length.
        ds = ds.filter(
            lambda example: formatting_func(sample=example, do_check=True) <= orpo_args.max_length,
            load_from_cache_file=True
        )
    ds = ds.map(
        formatting_func,
        num_proc=multiprocessing.cpu_count(),
        load_from_cache_file=True,
    )
    ds = ds.shuffle(seed=orpo_args.seed)

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation=args.manual_attn_implementation,
        quantization_config={  # Took these settings from what my Axolotl QLoRAs used
            "load_in_4bit": True,
            "load_in_8bit": False,
            "bnb_4bit_compute_dtype": "bfloat16",
            "bnb_4bit_quant_storage": "bfloat16",
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_use_double_quant": True,
            "llm_int8_enable_fp32_cpu_offload": False,
            "llm_int8_has_fp16_weight": False,
            "llm_int8_threshold": 6.0
        }
    )
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Load PEFT
    peft_config = LoraConfig(
        # BUGFIX: rank and alpha were swapped (manual_lora_rank was passed to
        # lora_alpha and vice versa) — latent only while both defaults were 32.
        r=args.manual_lora_rank,
        lora_alpha=args.manual_lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "gate_proj",
            "down_proj",
            "up_proj",  # BUGFIX: missing comma silently concatenated this with "q_proj" into "up_projq_proj"
            "q_proj",
            "v_proj",
            "k_proj",
            "o_proj",
        ],
        use_dora=args.manual_use_dora
    )

    # Start training
    trainer = ORPOTrainer(
        model,
        args=orpo_args,
        train_dataset=ds["train"],
        tokenizer=tokenizer,
        peft_config=peft_config,
    )

    # Train and save the model
    trainer.train()
    trainer.save_model(orpo_args.output_dir)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment