CodeLlama / CausalLM LoRA Training Code
transformers
tqdm
accelerate
datasets
peft
scipy
bitsandbytes
wandb
nvitop
flash-attn
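
For reference, the training script below expects the dataset file (all.jsonl by default) to contain one JSON object per line with a "text" field, since it reads json.loads(line)["text"]. A minimal sketch of producing such a file — the example strings are placeholders, not real training data:

import json

samples = [
    {"text": "def add(a, b):\n    return a + b\n"},
    {"text": "def greet(name):\n    return f'Hello, {name}!'\n"},
]
with open("all.jsonl", "w") as f:
    for sample in samples:
        # Each line becomes one training example with a "text" key
        f.write(json.dumps(sample) + "\n")
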
import json
import random
import os
from dataclasses import dataclass, field

import torch
from datasets import Dataset
from transformers import (
    BitsAndBytesConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    DataCollatorForLanguageModeling,
    set_seed,
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training,
)

@dataclass
class CustomArguments:
    model_name: str = field(
        default="codellama/CodeLlama-7b-hf",
        metadata={"help": "The name of the model to train."},
    )
    dataset_filepath: str = field(
        default="all.jsonl",
        metadata={"help": "The path to the dataset to use for training."},
    )
    trained_model_filepath: str = field(
        default="trained",
        metadata={"help": "The path to the trained model."},
    )
    test_split: float = field(
        default=0.1,
        metadata={"help": "The percentage of the dataset to use for testing."},
    )
    random_seed: int = field(
        default=42,
        metadata={"help": "The seed to use to control randomness."},
    )
    lora_rank: int = field(
        default=8,
        metadata={"help": "The rank of the LoRA model."},
    )
    lora_alpha: int = field(
        default=32,
        metadata={"help": "The alpha value of the LoRA model."},
    )
    lora_dropout: float = field(
        default=0.1,
        metadata={"help": "The dropout rate of the LoRA model."},
    )
    tokenize_batch_size: int = field(
        default=128,
        metadata={"help": "The batch size to use for tokenizing the dataset."},
    )
    max_length: int = field(
        default=4096,
        metadata={"help": "The maximum length of the input sequence."},
    )
    use_flash_attention: bool = field(
        default=False,
        metadata={"help": "Whether to use flash attention."},
    )
    quantize_4bit: bool = field(
        default=False,
        metadata={"help": "Whether to quantize the model to 4 bits."},
    )
    quantize_8bit: bool = field(
        default=False,
        metadata={"help": "Whether to quantize the model to 8 bits."},
    )
    trainer_resume_from_checkpoint: str = field(
        default="",
        metadata={"help": "Path to a checkpoint to resume from."},
    )

def _labeled_tokenize(text_batch, tokenizer, max_length):
    # Tokenize a batch of raw strings and copy input_ids into labels for causal LM training
    examples = tokenizer(
        text_batch,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    examples["labels"] = examples["input_ids"].clone()
    return examples


def process_dataset(dataset, tokenizer, max_length, batch_size):
    # Wrap the raw strings in a datasets.Dataset and tokenize them in batches
    return Dataset.from_list([{"text": t} for t in dataset]).map(
        lambda batch: _labeled_tokenize(batch["text"], tokenizer, max_length),
        batched=True,
        batch_size=batch_size,
    )

def main():
    parser = HfArgumentParser((TrainingArguments, CustomArguments))
    training_args, custom_args = parser.parse_args_into_dataclasses()
    set_seed(custom_args.random_seed)

    # Load the dataset
    with open(custom_args.dataset_filepath) as f:
        dataset = [json.loads(line)["text"] for line in f.readlines() if line.strip()]
    random.shuffle(dataset)

    # Split the dataset
    test_split_index = int(len(dataset) * custom_args.test_split)
    train_dataset, test_dataset = dataset[test_split_index:], dataset[:test_split_index]

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(custom_args.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Process the dataset
    train_dataset = process_dataset(
        train_dataset,
        tokenizer,
        custom_args.max_length,
        custom_args.tokenize_batch_size,
    )
    test_dataset = process_dataset(
        test_dataset, tokenizer, custom_args.max_length, custom_args.tokenize_batch_size
    )

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        custom_args.model_name,
        torch_dtype=torch.bfloat16,
        use_cache=False,
        use_flash_attention_2=custom_args.use_flash_attention,
        # Pin each process to its own GPU when launched via accelerate; fall back to GPU 0
        device_map=f'cuda:{os.environ.get("LOCAL_RANK", "0")}',
        load_in_8bit=custom_args.quantize_8bit,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        if custom_args.quantize_4bit
        else None,
    )
    model.config.pretraining_tp = 1
    model.enable_input_require_grads()

    peft_model = model
    if custom_args.quantize_4bit or custom_args.quantize_8bit:
        peft_model = prepare_model_for_kbit_training(peft_model, True)
    peft_model = get_peft_model(
        peft_model,
        LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=custom_args.lora_rank,
            lora_alpha=custom_args.lora_alpha,
            lora_dropout=custom_args.lora_dropout,
            bias="none",
            target_modules=["gate_proj", "down_proj", "up_proj"],
        ),
    )
    peft_model.print_trainable_parameters()

    # Train the model
    trainer = Trainer(
        model=peft_model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )
    trainer.train(
        resume_from_checkpoint=custom_args.trainer_resume_from_checkpoint or None
    )

    # Save the model
    trained_path = os.path.join(os.getcwd(), custom_args.trained_model_filepath)
    os.makedirs(trained_path, exist_ok=True)
    peft_model.save_pretrained(trained_path)
    tokenizer.save_pretrained(trained_path)


if __name__ == "__main__":
    main()
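
Once training finishes, the adapter saved under trained/ can be loaded on top of the base model for a quick sanity check. A minimal sketch, assuming the default model_name and trained_model_filepath values above; the prompt string is only an illustration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("trained")
base = AutoModelForCausalLM.from_pretrained(
    "codellama/CodeLlama-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "trained")
model.eval()

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
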
#!/usr/bin/env bash
accelerate launch \
  --multi_gpu \
  --num_processes 2 \
  train.py \
  --optim adamw_bnb_8bit \
  --evaluation_strategy steps \
  --dataset_filepath all.jsonl \
  --output_dir training \
  --save_total_limit 3 \
  --trained_model_filepath trained \
  --gradient_checkpointing True \
  --ddp_find_unused_parameters False \
  --max_grad_norm 0.3 \
  --learning_rate 5e-4 \
  --test_split 0.05 \
  --lora_rank 8 \
  --lora_alpha 256 \
  --num_train_epochs 1 \
  --logging_steps 1 \
  --eval_steps 20 \
  --per_device_train_batch_size 5 \
  --per_device_eval_batch_size 5 \
  --tokenize_batch_size 1024 \
  --gradient_accumulation_steps 1 \
  --max_length 2048 \
  --use_flash_attention True \
  --bf16
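
If a standalone checkpoint is preferred for serving, the LoRA weights can be folded back into the base model with PEFT's merge_and_unload. A hedged sketch, again assuming the default paths ("trained" for the adapter; "merged" is an arbitrary output directory, not part of the gist):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "codellama/CodeLlama-7b-hf", torch_dtype=torch.bfloat16
)
# Apply the saved adapter, merge its weights into the base model, and drop the LoRA layers
merged = PeftModel.from_pretrained(base, "trained").merge_and_unload()
merged.save_pretrained("merged")
AutoTokenizer.from_pretrained("trained").save_pretrained("merged")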