from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig
from peft.tuners.lora import LoraLayer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
)
from trl import SFTTrainer


@dataclass
class ModelArguments:
    """
    Arguments for creating and preparing the model.
    """
    model_name: str = field(
        default="tiiuae/falcon-7b",
        metadata={"help": "The model name or path from the Hugging Face Hub."},
    )
    use_4bit: bool = field(
        default=True,
        metadata={"help": "Activate 4-bit precision base model loading."},
    )
    use_nested_quant: bool = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4-bit base models."},
    )
    bnb_4bit_compute_dtype: str = field(
        default="float16",
        metadata={"help": "Compute dtype for 4-bit base models."},
    )
    bnb_4bit_quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization type: fp4 or nf4."},
    )
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.1)
    lora_r: int = field(default=64)


@dataclass
class ScriptArguments:
    """
    Arguments for model training and data handling.
    """
    local_rank: int = field(default=-1, metadata={"help": "Used for multi-GPU training."})
    per_device_train_batch_size: int = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[float] = field(default=0.001)
    max_seq_length: Optional[int] = field(default=512)
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The dataset to use for supervised fine-tuning."},
    )
    num_train_epochs: Optional[int] = field(
        default=1,
        metadata={"help": "The number of training epochs."},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables bf16 training."},
    )
    packing: Optional[bool] = field(
        default=False,
        metadata={"help": "Use packing when building the training dataset."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant is slightly better than cosine and makes analysis easier."},
    )
    max_steps: int = field(default=10000, metadata={"help": "How many optimizer update steps to take."})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for."})
    group_by_length: bool = field(
        default=True,
        metadata={
            "help": "Group sequences into batches of the same length. Saves memory and speeds up training considerably."
        },
    )
    save_steps: int = field(default=10, metadata={"help": "Save a checkpoint every X update steps."})
    logging_steps: int = field(default=10, metadata={"help": "Log every X update steps."})


def get_model_peftconfig_tokenizer(args: ModelArguments):
    """
    Create the model, tokenizer, and peft_config based on the provided arguments.
    """
    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)

    # Configure BitsAndBytes for 4-bit model quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=args.use_4bit,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=args.use_nested_quant,
    )

    # Alert the user when the GPU supports bfloat16 acceleration
    if compute_dtype == torch.float16 and args.use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16, you can accelerate training with --bf16")
            print("=" * 80)

    # Load the model with the quantization configuration
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name, quantization_config=bnb_config, device_map={"": 0}, trust_remote_code=True
    )

    # Define the LoRA configuration; target_modules lists Falcon's attention and MLP projections
    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "query_key_value",
            "dense",
            "dense_h_to_4h",
            "dense_4h_to_h",
        ],
    )

    # Load the tokenizer and set the padding token.
    # Models like Falcon-7B and GPT-2 do not define an official pad token,
    # so reuse the EOS token for padding.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    return model, peft_config, tokenizer


def parse_arguments():
    """
    Parse Model and Script Arguments.

    Returns:
        ModelArguments, ScriptArguments
    """
    parser = HfArgumentParser((ModelArguments, ScriptArguments))
    return parser.parse_args_into_dataclasses()

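
# Usage sketch: any field defined in the dataclasses above can be overridden on the
# command line via HfArgumentParser. The script filename below is an assumption, not
# part of the original gist.
#
#   python falcon_qlora_sft.py \
#       --model_name tiiuae/falcon-7b \
#       --dataset_name timdettmers/openassistant-guanaco \
#       --bf16 True \
#       --max_steps 500
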
def load_training_data(dataset_name: str):
    """
    Load dataset for training.

    Args:
        dataset_name (str): Name or path of the dataset.

    Returns:
        Dataset object
    """
    return load_dataset(dataset_name, split="train")

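
# Note: the default dataset (timdettmers/openassistant-guanaco) stores each conversation
# in a single "text" column, roughly of the form
#   "### Human: <prompt> ### Assistant: <response> ..."
# which is why the trainer below is configured with dataset_text_field="text".
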
def get_training_args(script_args: ScriptArguments):
    """
    Build TrainingArguments from ScriptArguments.

    Args:
        script_args (ScriptArguments): Parsed ScriptArguments.

    Returns:
        TrainingArguments
    """
    # Map the hyperparameters defined in ScriptArguments onto TrainingArguments
    return TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        optim=script_args.optim,
        save_steps=script_args.save_steps,
        logging_steps=script_args.logging_steps,
        learning_rate=script_args.learning_rate,
        weight_decay=script_args.weight_decay,
        fp16=script_args.fp16,
        bf16=script_args.bf16,
        max_grad_norm=script_args.max_grad_norm,
        max_steps=script_args.max_steps,
        num_train_epochs=script_args.num_train_epochs,
        warmup_ratio=script_args.warmup_ratio,
        group_by_length=script_args.group_by_length,
        lr_scheduler_type=script_args.lr_scheduler_type,
    )


def adjust_model_for_bf16(trainer, bf16: bool):
    """
    Adjust model layers for bf16 training.

    Args:
        trainer (SFTTrainer): Initialized SFTTrainer object.
        bf16 (bool): Flag to indicate usage of bf16.
    """
    for name, module in trainer.model.named_modules():
        # Cast LoRA layers to bf16 when bf16 training is enabled
        if isinstance(module, LoraLayer) and bf16:
            module.to(torch.bfloat16)
        # Keep normalization layers in float32 for numerical stability
        if "norm" in name:
            module.to(torch.float32)
        # Cast the output head and embeddings to bf16 if they are still in float32
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight") and bf16 and module.weight.dtype == torch.float32:
                module.to(torch.bfloat16)


# Main Execution:
model_args, script_args = parse_arguments()
model, peft_config, tokenizer = get_model_peftconfig_tokenizer(model_args)
# Disable the KV cache; it is not needed for training and conflicts with gradient checkpointing
model.config.use_cache = False

dataset = load_training_data(script_args.dataset_name)
training_arguments = get_training_args(script_args)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=script_args.max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=script_args.packing,
)

adjust_model_for_bf16(trainer, script_args.bf16)

# Train the Model
trainer.train()
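
# Optionally persist the trained LoRA adapter and tokenizer after training.
# A minimal sketch; the output path is an illustrative assumption:
# trainer.model.save_pretrained("./results/final_adapter")
# tokenizer.save_pretrained("./results/final_adapter")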