@h-lunah
Last active January 18, 2025 14:37
Distributed Hugging Face Training
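The script below runs QLoRA supervised fine-tuning with TRL's SFTTrainer, optionally distributed across GPUs via torchrun with DeepSpeed ZeRO-3 and Hugging Face Accelerate, then merges the LoRA adapters into a BF16 model and can upload the results to the Hub. As a quick orientation, here is a sketch of the environment variables it reads and how it is launched; the variable names come from the code below, and `script_name.py` is the same placeholder the script's own error message uses.

# Environment variables read by the script:
#   HF_BASE_MODEL_NAME   base model to fine-tune
#   HF_DATASET_NAME      dataset with a "conversations" column of role/content messages
#   N_EPOCHS, N_STEPS    optional overrides for training length
#   HF_USERNAME, HF_MODEL_NAME, HF_ADAPTER_NAME   only needed if uploading to the Hub
#
# Launching:
#   Single GPU:    python script_name.py
#   Multiple GPUs: torchrun --nproc_per_node=<num_gpus> script_name.py
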
import os
import json
import random

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
from torch.distributed import (
    init_process_group,
    destroy_process_group,
    is_available as is_dist_available,
)
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel
from huggingface_hub import HfApi, create_repo


def setup_distributed():
    # Only run distributed training if multiple GPUs are detected
    if is_dist_available() and torch.cuda.device_count() > 1:
        init_process_group(backend="nccl")
        # LOCAL_RANK is set by torchrun; convert it to an int so it can be reused directly
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(f"cuda:{local_rank}")
        print(f"Distributed training has been enabled. Local rank: {local_rank}")
        return local_rank
    return 0


def display_message(accelerator, message):
    # Only display messages in the main process
    if accelerator.is_main_process:
        print(message)


def check_gpus(accelerator):
    # Map compute capability to GPU family
    def get_gpu_family(compute_capability):
        major, minor = compute_capability
        # Check the newest architectures first so that Ada Lovelace (8.9) is not
        # swallowed by the generic Ampere (8.x) branch
        if major == 9:
            return "Hopper"
        elif major == 8 and minor >= 9:
            return "Ada Lovelace"
        elif major == 8:
            return "Ampere"
        elif major == 7:
            if minor == 5:
                return "Turing"
            else:
                return "Volta"
        elif major == 6:
            return "Pascal"
        elif major == 5:
            return "Maxwell"
        elif major == 3:
            return "Kepler"
        elif major == 2:
            return "Fermi"
        elif major == 1:
            return "Tesla"
        else:
            raise RuntimeError("Unsupported or non-NVIDIA GPU found.")

    # Check the number of available GPUs and display relevant messages
    num_gpus = torch.cuda.device_count()
    display_message(accelerator, "Detecting CUDA devices...")
    for i in range(num_gpus):
        device_name = torch.cuda.get_device_name(i)
        compute_capability = torch.cuda.get_device_capability(i)
        gpu_family = get_gpu_family(compute_capability)
        display_message(
            accelerator,
            f"CUDA:{i} - {device_name} - Compute Capability: {compute_capability[0]}.{compute_capability[1]} - Family: {gpu_family}",
        )
    if num_gpus == 0:
        raise RuntimeError("No CUDA devices are available")
    elif num_gpus == 1:
        display_message(accelerator, "Distributed training has been disabled.")
    elif num_gpus > 1 and not os.environ.get("LOCAL_RANK"):
        raise RuntimeError(
            f"Found multiple CUDA devices ({num_gpus}). Rerun this script using the command: `torchrun --nproc_per_node={num_gpus} script_name.py`"
        )


def cleanup_distributed():
    # Shut down the training node
    if is_dist_available() and torch.cuda.device_count() > 1:
        destroy_process_group()


def load_and_prepare_dataset(tokenizer):
    # Load the dataset
    dataset = load_dataset(os.environ["HF_DATASET_NAME"])

    # Split the dataset into training and evaluation sets
    # Some data will be taken for evaluation purposes.
    train_test_split = dataset["train"].train_test_split(
        test_size=0.01, seed=random.randint(0, 10000)
    )

    # Access the training and evaluation datasets
    train_dataset = train_test_split["train"]
    eval_dataset = train_test_split["test"].select(range(10))

    def preprocess_function(examples):
        # Initialize new array
        texts = []

        # Format conversations in the correct way
        # If you are getting a KeyError, check the conversation row name.
        for entry in examples["conversations"]:
            template = ""
            for message in entry:
                template += f"<|im_start|>{message['role']}\n{message['content']}\n<|im_end|>\n"
            texts.append(template)

        # Tokenize the texts
        tokenized = tokenizer(
            texts,
            truncation=True,
            max_length=2048,
            padding="max_length",
            return_tensors="pt",
        )

        # Set up the labels for language modeling
        tokenized["labels"] = tokenized["input_ids"].clone()
        return tokenized

    # Process the training dataset
    tokenized_train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=train_dataset.column_names,
    )

    # Process the evaluation dataset
    tokenized_eval_dataset = eval_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=eval_dataset.column_names,
    )
    return tokenized_train_dataset, tokenized_eval_dataset


def prepare_model_for_training(model_name, local_rank):
    # Check if all GPUs are Ampere or newer
    def is_ampere_or_newer(device_index):
        capability = torch.cuda.get_device_capability(device_index)
        return capability[0] >= 8  # Ampere is compute capability 8.0 or higher

    # Determine if ALL GPUs are Ampere or newer
    all_gpus_ampere = all(is_ampere_or_newer(i) for i in range(torch.cuda.device_count()))

    try:
        # Try loading the model with Flash Attention 2 only if ALL GPUs are Ampere or newer
        if all_gpus_ampere:
            if torch.cuda.device_count() > 1:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                        bnb_4bit_compute_dtype=torch.float16,
                    ),
                    device_map=None,
                    attn_implementation="flash_attention_2",  # Enable Flash Attention 2
                    trust_remote_code=True,
                )
                model = model.to(f"cuda:{local_rank}")
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                        bnb_4bit_compute_dtype=torch.float16,
                    ),
                    device_map="auto",
                    attn_implementation="flash_attention_2",  # Enable Flash Attention 2
                    trust_remote_code=True,
                )
            print(f"Flash Attention 2 turned on for current device: cuda:{local_rank}")
        else:
            raise ValueError("Flash Attention 2 unavailable for current device")
    except (ValueError, ImportError):
        # Fall back to default attention if Flash Attention 2 is not supported
        # (ImportError covers the case where the flash-attn package is not installed)
        print(f"Flash Attention 2 turned off for current device: cuda:{local_rank}")
        if torch.cuda.device_count() > 1:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                ),
                device_map=None,
                trust_remote_code=True,
            )
            model = model.to(f"cuda:{local_rank}")
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                ),
                device_map="auto",
                trust_remote_code=True,
            )
        try:
            model.enable_xformers_memory_efficient_attention()
            print(f"xFormers turned on for current device: cuda:{local_rank}")
        except AttributeError:
            print("xFormers turned off because the current model does not support it.")

    # Prepare the model for k-bit (QLoRA) training
    model = prepare_model_for_kbit_training(model)

    # Configure LoRA
    lora_config = LoraConfig(
        # Use 16 for defining new data, 128 for when the model fails to capture overrides
        r=16,
        # Setting the alpha-to-rank ratio to 2:1 is recommended
        lora_alpha=32,
        target_modules="all-linear",
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        use_rslora=True,
    )

    # Get PEFT model
    model = get_peft_model(model, lora_config)
    return model


def merge_and_save_bf16(accelerator, adapter_model_path, base_model_name, output_dir):
    # Load to CPU to save VRAM during exporting
    # You need a significant amount of memory to load large models as BF16.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True
    )

    # Load the LoRA model
    adapter_model = PeftModel.from_pretrained(base_model, adapter_model_path)
    display_message(accelerator, "Merging...")

    # Merge the LoRA weights with the base model
    merged_model = adapter_model.merge_and_unload()
    display_message(accelerator, "Casting...")

    # Convert to BF16
    merged_model = merged_model.to(torch.bfloat16)

    # Save the merged model in HF split format
    display_message(accelerator, "Exporting...")
    merged_model.save_pretrained(f"{output_dir}_merged_bf16", max_shard_size="5GB")
    return f"{output_dir}_merged_bf16"


def main():
    # Set up the distributed training node and accelerator
    local_rank = setup_distributed()
    with open("ds_config.json", "r") as f:
        ds_config = json.load(f)
    ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
        mixed_precision="fp16",
        deepspeed_plugin=ds_plugin,
    )
    check_gpus(accelerator)

    # Load tokenizer and model
    model_name = os.environ["HF_BASE_MODEL_NAME"]
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Set the padding token explicitly; it may be unset in some cases
    tokenizer.pad_token = tokenizer.eos_token

    # Some tokenizers may break with left-side padding, so use right-side padding instead
    tokenizer.padding_side = "right"

    # Load the model with 4-bit quantization using the quantization config
    model = prepare_model_for_training(model_name, local_rank)

    # Prepare the dataset
    tokenized_train_dataset, tokenized_eval_dataset = load_and_prepare_dataset(tokenizer)

    # Set up training arguments
    training_args = SFTConfig(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        num_train_epochs=float(os.environ.get("N_EPOCHS", 1)),
        max_steps=int(os.environ.get("N_STEPS", 0)),
        learning_rate=2e-4 * torch.cuda.device_count(),
        fp16=True,
        logging_steps=1,
        eval_strategy="steps",
        eval_steps=50,
        optim="adamw_torch_fused",  # fused AdamW kernel; faster on recent GPUs
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=random.randint(0, 10000),
        output_dir="./checkpoints",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        logging_dir="logs",
        max_seq_length=2048,
        report_to="tensorboard",
        local_rank=local_rank,
        ddp_find_unused_parameters=False,
        save_total_limit=1,
    )

    # Initialize the trainer
    trainer = accelerator.prepare(
        SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        )
    )

    # Check if a checkpoint exists and load it
    checkpoint_dir = "../input/checkpoints"  # Compatible with Kaggle's working directory format
    if os.path.exists(checkpoint_dir):
        display_message(accelerator, f"Resuming training from checkpoint: {checkpoint_dir}")
        trainer.train(resume_from_checkpoint=checkpoint_dir)
    else:
        display_message(accelerator, "No checkpoint found. Starting training from scratch.")
        trainer.train()

    # Save the quantized model
    if accelerator.is_main_process:
        # Save the LoRA adapters
        adapter_path = "./final_lora_adaptors"
        trainer.model.save_pretrained(adapter_path)
        display_message(accelerator, "Saving merged model...")

        # Move to CPU for merging
        trainer.model.cpu()
        torch.cuda.empty_cache()

        # Merge LoRA weights with the base model and save
        bf16_model_path = merge_and_save_bf16(accelerator, adapter_path, model_name, "./final_model")
        display_message(accelerator, "Model saved")

        # Save the tokenizer alongside both artifacts
        tokenizer.save_pretrained(adapter_path)
        tokenizer.save_pretrained(bf16_model_path)
        display_message(accelerator, "Tokenizers saved")

        # Initialize the Hugging Face API
        api = HfApi()

        # Define your Hugging Face username and model names
        username = os.environ["HF_USERNAME"]  # Populate this with your Hugging Face username
        merged_model_name = os.environ["HF_MODEL_NAME"]  # Populate this with the desired model upload name
        adapter_model_name = os.environ["HF_ADAPTER_NAME"]  # Populate this with your desired adapter upload name
        should_upload = False  # Set to True to upload the model automatically (requires token)
        should_upload_lora = False  # Set to True to upload the adapters automatically (requires token)

        if should_upload:
            # Upload the merged model
            display_message(accelerator, "Uploading merged model...")
            create_repo(merged_model_name, private=False, exist_ok=True)
            api.upload_folder(
                folder_path=bf16_model_path,
                repo_id=f"{username}/{merged_model_name}",
                commit_message="Upload merged BF16 model",
            )
            display_message(
                accelerator,
                f"Merged model uploaded\nhttps://huggingface.co/{username}/{merged_model_name}",
            )

        if should_upload_lora:
            # Upload the LoRA adapters
            display_message(accelerator, "Uploading LoRA adapters...")
            create_repo(adapter_model_name, private=False, exist_ok=True)
            api.upload_folder(
                folder_path=adapter_path,
                repo_id=f"{username}/{adapter_model_name}",
                commit_message="Upload LoRA adapters",
            )
            display_message(
                accelerator,
                f"LoRA adapters uploaded\nhttps://huggingface.co/{username}/{adapter_model_name}",
            )

    # Cleanup
    display_message(accelerator, "Training completed.")
    cleanup_distributed()


if __name__ == "__main__":
    main()


ds_config.json

{
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "weight_decay": "auto",
      "torch_adam": true,
      "adam_w_mode": true
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "sub_group_size": 1e9,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": "auto"
  },
  "gradient_accumulation_steps": 1,
  "gradient_clipping": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
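
After training completes, the main process writes the merged BF16 model and tokenizer to ./final_model_merged_bf16. The snippet below is a minimal smoke-test sketch for loading that folder and generating with the same <|im_start|>/<|im_end|> template used during preprocessing; the prompt text and generation settings are illustrative and not part of the gist.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path produced by merge_and_save_bf16() in the training script
model_dir = "./final_model_merged_bf16"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

# Mirror the chat template built in load_and_prepare_dataset()
prompt = "<|im_start|>user\nHello!\n<|im_end|>\n<|im_start|>assistant\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=False))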