
@younesbelkada
Last active September 19, 2024 12:10
Fine-tune Llama v2 models on the Guanaco dataset
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
)
from trl import SFTTrainer
# This example fine-tunes a Llama v2 model on the Guanaco dataset
# using QLoRA. At the end of the script, the LoRA weights can be merged into the base model.
# Select the model by passing the --model_name argument when running the
# script.
#
# Versions used:
# accelerate == 0.21.0
# peft == 0.4.0
# bitsandbytes == 0.40.2
# transformers == 4.31.0
# trl == 0.4.7
# For models that have `config.pretraining_tp > 1` install:
# pip install git+https://github.com/huggingface/transformers.git
@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
    """
    local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[float] = field(default=0.001)
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    max_seq_length: Optional[int] = field(default=512)
    model_name: Optional[str] = field(
        default="meta-llama/Llama-2-7b-hf",
        metadata={
            "help": "The causal LM that you want to train from the Hugging Face hub, e.g. meta-llama/Llama-2-7b-hf or gpt2."
        },
    )
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The instruction dataset to use."},
    )
    use_4bit: Optional[bool] = field(
        default=True,
        metadata={"help": "Activate 4bit precision base model loading"},
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
    )
    num_train_epochs: Optional[int] = field(
        default=1,
        metadata={"help": "The number of training epochs."},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables bf16 training."},
    )
    packing: Optional[bool] = field(
        default=False,
        metadata={"help": "Use packing when creating the dataset."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant is a bit better than cosine and has advantages for analysis."},
    )
    max_steps: int = field(default=10000, metadata={"help": "How many optimizer update steps to take"})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for"})
    group_by_length: bool = field(
        default=True,
        metadata={
            "help": "Group sequences of similar length into batches. Saves memory and speeds up training considerably."
        },
    )
    save_steps: int = field(default=10, metadata={"help": "Save a checkpoint every X update steps."})
    logging_steps: int = field(default=10, metadata={"help": "Log every X update steps."})
    merge_and_push: Optional[bool] = field(
        default=False,
        metadata={"help": "Merge the LoRA weights into the base model and save them after training"},
    )
    output_dir: str = field(
        default="./results",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
def create_and_prepare_model(args):
    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=args.use_4bit,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=args.use_nested_quant,
    )
    if compute_dtype == torch.float16 and args.use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
            print("=" * 80)
    # Load the entire model on GPU 0
    # switch to `device_map = "auto"` for multi-GPU
    device_map = {"": 0}
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        quantization_config=bnb_config,
        device_map=device_map,
        use_auth_token=True,
    )
    # check: https://github.com/huggingface/transformers/pull/24906
    model.config.pretraining_tp = 1
    # Use the function argument rather than the global script_args
    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    return model, peft_config, tokenizer
training_arguments = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optim=script_args.optim,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    learning_rate=script_args.learning_rate,
    weight_decay=script_args.weight_decay,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
    max_grad_norm=script_args.max_grad_norm,
    max_steps=script_args.max_steps,
    warmup_ratio=script_args.warmup_ratio,
    group_by_length=script_args.group_by_length,
    lr_scheduler_type=script_args.lr_scheduler_type,
)
model, peft_config, tokenizer = create_and_prepare_model(script_args)
model.config.use_cache = False
dataset = load_dataset(script_args.dataset_name, split="train")
# Right padding avoids an overflow issue (zero loss) observed with fp16 training
tokenizer.padding_side = "right"
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=script_args.max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=script_args.packing,
)
trainer.train()
if script_args.merge_and_push:
    output_dir = os.path.join(script_args.output_dir, "final_checkpoints")
    trainer.model.save_pretrained(output_dir)
    # Free memory for merging weights
    del model
    torch.cuda.empty_cache()
    from peft import AutoPeftModelForCausalLM
    model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
    model = model.merge_and_unload()
    output_merged_dir = os.path.join(script_args.output_dir, "final_merged_checkpoint")
    model.save_pretrained(output_merged_dir, safe_serialization=True)
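For orientation, the arithmetic implied by the default arguments can be checked in plain Python. This is a sketch under stated assumptions: a single GPU (the script pins the model to GPU 0) and the warmup formula `ceil(max_steps * warmup_ratio)` that transformers applies when `warmup_steps` is 0. If we read transformers' `get_scheduler` correctly, the plain `"constant"` schedule ignores warmup entirely (`"constant_with_warmup"` uses it), so the warmup figure only matters if the scheduler is changed. The helper name `describe_defaults` is hypothetical.

```python
import math


def describe_defaults(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_gpus=1,  # assumption: the script loads the whole model on GPU 0
    max_seq_length=512,
    max_steps=10000,
    warmup_ratio=0.03,
    lora_alpha=16,
    lora_r=64,
):
    """Derive a few quantities implied by the ScriptArguments defaults."""
    # Sequences that contribute to one optimizer update.
    effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
    # Upper bound on tokens per update (sequences are truncated to max_seq_length).
    max_tokens_per_step = effective_batch * max_seq_length
    # Warmup steps derived from warmup_ratio when warmup_steps == 0.
    warmup_steps = math.ceil(max_steps * warmup_ratio)
    # LoRA scales the low-rank update B @ A by lora_alpha / r; merge_and_unload()
    # folds W + (lora_alpha / r) * B @ A into the base weights.
    lora_scaling = lora_alpha / lora_r
    return {
        "effective_batch": effective_batch,
        "max_tokens_per_step": max_tokens_per_step,
        "warmup_steps": warmup_steps,
        "lora_scaling": lora_scaling,
    }


print(describe_defaults())
```

Passing `num_gpus=2` shows how the effective batch doubles once the script is actually run data-parallel.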
@tytung2020
tytung2020 commented Aug 4, 2023

What is the best way to load a pretrained checkpoint to continue training?

edit: trainer.train(resume_from_checkpoint=True) does the job.

I did that but it gives this error:
ValueError: No valid checkpoint found in output directory

This error occurs both when the model is fine-tuned for the first time and when checkpoint folders from an earlier fine-tuning run (one that called plain trainer.train()) already exist in the directory.
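With `resume_from_checkpoint=True`, the Trainer looks for `checkpoint-<step>` subfolders in `output_dir` and raises this ValueError when it finds none, which explains the first case (nothing saved yet). A defensive pattern is to look for the latest checkpoint yourself and resume only when one exists. transformers ships `transformers.trainer_utils.get_last_checkpoint(folder)` for this; the sketch below re-implements the same idea with the standard library so the logic is visible. `find_last_checkpoint` is a hypothetical name, and the second case (folders present but still rejected) is not diagnosed here; one thing worth checking is that the resumed run uses the same `--output_dir`.

```python
import os
import re

# Naming scheme the Trainer uses when saving: "<output_dir>/checkpoint-<global_step>".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir):
    """Return the path of the highest-step checkpoint folder, or None if there is none."""
    if not os.path.isdir(output_dir):
        return None
    steps = {}
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps[int(match.group(1))] = name
    if not steps:
        return None
    # Compare numerically so that checkpoint-100 beats checkpoint-20.
    return os.path.join(output_dir, steps[max(steps)])


# Usage in the script (sketch): resume only when a checkpoint actually exists.
# last = find_last_checkpoint(script_args.output_dir)
# trainer.train(resume_from_checkpoint=last)  # None starts a fresh run
```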

@huseinzol05
huseinzol05 commented Aug 8, 2023

Currently the model is very bad at generating the <EOS> token to stop early. This is because we set tokenizer.pad_token = tokenizer.eos_token, so the collator replaces every pad_token_id with -100 (https://github.com/huggingface/transformers/blob/main/src/transformers/data/data_collator.py#L747), which also removes the <EOS> token from the loss:

# input ids

tensor([    1,   396,  2659, 29901,  1426,   421, 10994,   669,  1160, 29936,
         7197,  7250, 29892, 20147,  5834, 29920, 20962, 19548,   277, 29017,
         6836, 29889,   897, 29895,   271,   413,   287,  1794,   365,  5921,
          470,  6940,   574,   413,   532, 29892,  1401,   273,  4971, 12178,
         2849,   435,  9010, 29892,   594, 29874,  1673, 11608,   675,   278,
         1426,   411,  3858,  6024, 22198,   742,   525,  1066,  3321,   742,
          525, 17821,  1705,  2033,   322,  5649, 29892,   736,   408,  4663,
         1820, 11117, 18616,  2073,   742,   525,  4548,  7420, 29918,   264,
          742,   525,  4548,  7420, 29918,  1516, 10827,    13, 29937, 29933,
          327, 29901,   426,    13,  1678,   376, 18616,  2073,  1115,   376,
        17821,  1705,   613,    13,  1678,   376,  4548,  7420, 29918,   264,
         1115,   376,  1576,  1426,   338, 21104,   408,   372,   947,   451,
         4653,   738,  4549,  6374,   470,  8178,   953,  8194, 29889,   739,
          338,  3763, 13138,  2472,  1048,   263,  4423, 19602,    13,  1678,
          376,  4548,  7420, 29918,  1516,  1115,   376, 29911, 14541,   297,
        29875,   594,   284,   801, 21104, 13023,  1648, 29871,   423, 10668,
          557,   286,   996,   686, 21474, 11052,   953,  8156, 13686,   361,
          472,   585,  3480, 25572,   343,   574,   413, 29884,   271, 29889,
          306, 29874,   298, 20912,  4509,  7941,  2136, 29880,   398,   271,
        12033,   574, 25194,  6840,  1213,    13, 29913,     2,     2,     2,
            2,     2,     2,  ...,     2,     2,     2])

(abbreviated: after token 29913 the row continues with 665 copies of id 2, for 862 tokens in total)

# labels

tensor([    1,   396,  2659, 29901,  1426,   421, 10994,   669,  1160, 29936,
         7197,  7250, 29892, 20147,  5834, 29920, 20962, 19548,   277, 29017,
         6836, 29889,   897, 29895,   271,   413,   287,  1794,   365,  5921,
          470,  6940,   574,   413,   532, 29892,  1401,   273,  4971, 12178,
         2849,   435,  9010, 29892,   594, 29874,  1673, 11608,   675,   278,
         1426,   411,  3858,  6024, 22198,   742,   525,  1066,  3321,   742,
          525, 17821,  1705,  2033,   322,  5649, 29892,   736,   408,  4663,
         1820, 11117, 18616,  2073,   742,   525,  4548,  7420, 29918,   264,
          742,   525,  4548,  7420, 29918,  1516, 10827,    13, 29937, 29933,
          327, 29901,   426,    13,  1678,   376, 18616,  2073,  1115,   376,
        17821,  1705,   613,    13,  1678,   376,  4548,  7420, 29918,   264,
         1115,   376,  1576,  1426,   338, 21104,   408,   372,   947,   451,
         4653,   738,  4549,  6374,   470,  8178,   953,  8194, 29889,   739,
          338,  3763, 13138,  2472,  1048,   263,  4423, 19602,    13,  1678,
          376,  4548,  7420, 29918,  1516,  1115,   376, 29911, 14541,   297,
        29875,   594,   284,   801, 21104, 13023,  1648, 29871,   423, 10668,
          557,   286,   996,   686, 21474, 11052,   953,  8156, 13686,   361,
          472,   585,  3480, 25572,   343,   574,   413, 29884,   271, 29889,
          306, 29874,   298, 20912,  4509,  7941,  2136, 29880,   398,   271,
        12033,   574, 25194,  6840,  1213,    13, 29913,  -100,  -100,  -100,
         -100,  -100,  -100,  ...,  -100,  -100,  -100])

(abbreviated: after token 29913 every remaining position, 665 in total, is -100)

The labels do not contain EOS (token id 2) anywhere.
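The effect can be reproduced without a model. The sketch below mimics the collator's `labels[labels == pad_token_id] = -100` step on plain Python lists, assuming Llama's EOS id 2 and an example whose text is followed by an EOS before the padding (in the dump above the text has no explicit EOS, so the first 2 after the text plays that role). When the pad id equals the EOS id, the real EOS is masked together with the padding; masking by position, i.e. everything after the first EOS, keeps it. `mask_like_collator` and `mask_after_first_eos` are hypothetical helpers, not library functions.

```python
EOS_ID = 2  # Llama's </s> id, as in the dump above
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss


def mask_like_collator(input_ids, pad_token_id):
    """Copy input_ids into labels and hide every position equal to pad_token_id."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


def mask_after_first_eos(input_ids, eos_token_id=EOS_ID):
    """Keep the first EOS as a label and hide everything after it (the padding)."""
    labels = list(input_ids)
    if eos_token_id in labels:
        first = labels.index(eos_token_id)
        for i in range(first + 1, len(labels)):
            labels[i] = IGNORE_INDEX
    return labels


example = [1, 396, 2659, 29913, EOS_ID, EOS_ID, EOS_ID]  # text, EOS, then padding
print(mask_like_collator(example, pad_token_id=EOS_ID))  # EOS is lost
print(mask_after_first_eos(example))  # EOS survives
```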

@nehzata
nehzata commented Aug 11, 2023

@huseinzol05 & @younesbelkada I came across the same problem with fine-tuned models not being able to generate EOS tokens. The resulting behaviour is that model.generate() only stops at max_new_tokens and just rambles on.

This modification seems to have solved the problem on my side:

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = "<PAD>"
tokenizer.padding_side = "right"

Further discussions in huggingface/transformers#22794. llama-recipes does something similar.
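Why a separate pad token helps: once the pad id differs from the EOS id, the collator's masking hides only the padding. A caveat we believe applies (an assumption worth verifying by printing `tokenizer.pad_token_id`): `"<PAD>"` is not in the Llama vocabulary, so the Llama tokenizer resolves it to the unknown-token id 0 rather than creating a new id, which is why no embedding resize is needed. The sketch below right-pads a batch and builds attention masks and labels with the collator's rule; it assumes each example already ends with EOS. `pad_and_label` is a hypothetical helper and the ids are illustrative.

```python
EOS_ID = 2  # Llama </s>
PAD_ID = 0  # assumption: what "<PAD>" resolves to (Llama's <unk> id)
IGNORE_INDEX = -100


def pad_and_label(sequences, pad_id=PAD_ID):
    """Right-pad token-id lists to equal length and build attention masks and labels.

    Each sequence is assumed to already end with EOS_ID.
    """
    width = max(len(seq) for seq in sequences)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for seq in sequences:
        n_pad = width - len(seq)
        ids = seq + [pad_id] * n_pad  # right padding, as the gist sets padding_side = "right"
        batch["input_ids"].append(ids)
        batch["attention_mask"].append([1] * len(seq) + [0] * n_pad)
        # Same rule as the collator: hide positions whose id equals the pad id.
        batch["labels"].append([IGNORE_INDEX if tok == pad_id else tok for tok in ids])
    return batch


out = pad_and_label([[1, 15, 16, EOS_ID], [1, 17, EOS_ID]])
print(out["labels"])  # the EOS label is kept in both rows
```

Calling the same helper with `pad_id=EOS_ID` reproduces the original problem: the EOS labels turn into -100.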

@younesbelkada I would appreciate it if you could clarify whether this is the correct approach or advise a better solution, please. Many thanks for making this gist available, by the way! Works awesome otherwise :)

@kochhar
kochhar commented Aug 11, 2023

@nehzata @kw2828 I tried this setting with the guardrail sharded 4-bit model on Google Colab (12 GB RAM / 16 GB VRAM), but with it the training fails with a CUDA device-side assert.

Any suggestions on what to look into to identify the cause of this?

@nehzata
nehzata commented Aug 11, 2023

@kochhar Yes, I came across that error as well. I don't recall whether it was related to LlamaTokenizer/LlamaForCausalLM or to using add_special_tokens()... However, tokenizer.pad_token = "<PAD>" seems to work without issues.
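A common cause of a CUDA device-side assert in this setup (our assumption, not confirmed in the thread) is a token id that is not smaller than the number of rows in the input embedding. `tokenizer.add_special_tokens({"pad_token": "<PAD>"})` adds a brand-new id (32000 for Llama 2's 32,000-token vocabulary) unless `model.resize_token_embeddings(len(tokenizer))` is called afterwards, whereas assigning `tokenizer.pad_token = "<PAD>"` reuses an existing id. Because CUDA errors surface asynchronously, a cheap CPU-side check before training gives a readable message. `check_token_ids` is a hypothetical helper; in the script, `num_embeddings` would come from `model.get_input_embeddings().num_embeddings`.

```python
def check_token_ids(batches, num_embeddings):
    """Raise a readable error if any token id cannot index the embedding matrix.

    `batches` is an iterable of lists of token ids; valid ids are 0 .. num_embeddings - 1.
    """
    for row, ids in enumerate(batches):
        bad = [tok for tok in ids if not 0 <= tok < num_embeddings]
        if bad:
            raise ValueError(
                f"row {row}: ids {sorted(set(bad))} out of range for "
                f"{num_embeddings} embeddings; resize the embeddings or reuse an existing token"
            )


VOCAB = 32000  # Llama 2 embedding rows before any resize
check_token_ids([[1, 15, 2, 0]], VOCAB)  # passes: the pad reuses an existing id
try:
    check_token_ids([[1, 15, 2, 32000]], VOCAB)  # a newly added pad token without resizing
except ValueError as err:
    print(err)
```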

@kochhar
kochhar commented Aug 12, 2023

@nehzata Thanks, yes that works 👍!!

@mkserge
mkserge commented Aug 31, 2023

The 0 loss has been reproduced, and it seems to be fixed by adding tokenizer.padding_side = "right" before the init of SFTTrainer. Some weird overflow happens when using left padding with the llama-2 7b model, but I can confirm this fixed things on my side for fp16 training. Fixed the gist accordingly.

@younesbelkada isn't the recommendation (see here, for example) to do exactly the opposite? Pad on the left, for decoder-only models, since you don't want the model to continue predicting from PAD tokens?
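As we understand it, both recommendations can hold at once: right padding during training (what the gist uses, since the attention mask and the -100 labels keep trailing pads out of the loss), and left padding for batched generation, because `generate()` appends new tokens after the last column, which must be a real token in every row. The sketch below pads a batch either way and shows which token sits in the last column; `pad_batch` is a hypothetical helper and the ids are illustrative. With transformers, the switch is setting `tokenizer.padding_side = "left"` before tokenizing prompts for generation.

```python
def pad_batch(sequences, pad_id, side="right"):
    """Pad token-id lists to the same length on the given side ('right' or 'left')."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    width = max(len(seq) for seq in sequences)
    padded, masks = [], []
    for seq in sequences:
        pad = [pad_id] * (width - len(seq))
        if side == "right":
            padded.append(seq + pad)
            masks.append([1] * len(seq) + [0] * len(pad))
        else:
            padded.append(pad + seq)
            masks.append([0] * len(pad) + [1] * len(seq))
    return padded, masks


prompts = [[1, 11, 12, 13], [1, 21]]  # illustrative token ids
for side in ("right", "left"):
    ids, _ = pad_batch(prompts, pad_id=0, side=side)
    # generate() continues from the last column: it must be a real token in every row.
    print(side, [row[-1] for row in ids])
```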

@shubhamagarwal92
Hi! Thanks for open-sourcing the gist! Could you please also share some scripts for batch generation from the model (using pipeline or model.generate)?

@swaroop11
Hi @younesbelkada, thanks for providing this script. I am using it to fine-tune my model on 4k question-answer pairs, which I converted into the Llama 2 prompt format before training. Training goes fine and the loss is decreasing, but at inference time the results are not good.
I have trained the Llama-2-7b-chat-hf model. Could you please provide some suggestions that would help me improve the results?

@nehzata Thanks for your suggestion, it worked and helped me with my training; now the model stops generating.
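Since the EOS fix above only helps if the training text contains an end-of-sequence marker, it may be worth checking how each Q&A pair is turned into text. The sketch below follows the Llama 2 chat layout as it is commonly documented (`[INST] ... [/INST]`, with an optional `<<SYS>>` block) and appends the EOS string so the model sees where an answer ends. It assumes that the tokenizer prepends the BOS token `<s>` by itself (the Llama tokenizer does so by default) and that it maps the literal `</s>` in the text to the EOS id; both are worth verifying by tokenizing one example. `format_llama2_example` is a hypothetical name.

```python
def format_llama2_example(question, answer, system_prompt=None, eos_token="</s>"):
    """Render one Q&A pair in the Llama 2 chat layout, ending with the EOS string.

    BOS (<s>) is left out on the assumption that the tokenizer prepends it.
    """
    user_turn = question.strip()
    if system_prompt:
        user_turn = f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n{user_turn}"
    return f"[INST] {user_turn} [/INST] {answer.strip()} {eos_token}"


print(format_llama2_example("What is QLoRA?", "LoRA fine-tuning on top of a 4-bit quantized base model."))
```

The resulting strings would go into the `text` column that SFTTrainer reads through `dataset_text_field="text"`, for example via `dataset.map(...)` over your own (hypothetically named) question and answer columns.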

@NPap0
NPap0 commented Oct 4, 2023

@swaroop11 Hey there, by "results are not good", did you mean that the predictions never stopped because of the EOS token issue (which you got fixed), or was it another issue?

Hi, I am using 2 T4 GPUs on Kaggle to fine-tune a sharded Llama 2 model with 4-bit quantization, but I am still running into CUDA out-of-memory errors. Is there any fix other than upgrading the GPU? I tried varying the batch size and the number of gradient accumulation steps.

@inamdarmihir
Hey, I am about to try the same thing, did you find any workarounds for this?
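Some rough numbers can help when debugging OOM here. Two hedged observations: the script pins the whole model to GPU 0 (`device_map = {"": 0}`), so a second T4 stays idle unless that is changed; and the trainable LoRA part is small. The sketch below counts LoRA parameters and estimates 4-bit weight memory, assuming Llama-2-7B shapes (32 layers, hidden size 4096, about 6.7B parameters) and that peft's default LoRA targets for Llama are `q_proj` and `v_proj`, since the script's `LoraConfig` sets no `target_modules`. Activations, optimizer state, quantization constants and CUDA overhead are ignored, so treat the results as a lower bound.

```python
def lora_trainable_params(num_layers, d_in, d_out, r, matrices_per_layer):
    """LoRA adds A (r x d_in) and B (d_out x r) for every adapted weight matrix."""
    return num_layers * matrices_per_layer * r * (d_in + d_out)


def weight_gib(num_params, bits_per_param):
    """Memory for the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 2**30


# Assumed Llama-2-7B shapes; q_proj and v_proj are both 4096 x 4096.
lora = lora_trainable_params(num_layers=32, d_in=4096, d_out=4096, r=64, matrices_per_layer=2)
print(f"trainable LoRA params: {lora:,}")  # r=64 as in the script defaults
print(f"base weights in 4-bit: ~{weight_gib(6.7e9, 4):.1f} GiB")
print(f"LoRA weights in fp32:  ~{weight_gib(lora, 32):.3f} GiB")
```

If the weights fit comfortably, the remaining pressure most likely comes from activations, which scale with `per_device_train_batch_size` and `max_seq_length`.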

@xinmengZ
xinmengZ commented Oct 5, 2023

@younesbelkada appreciate the code! I want to train llama2-70B on 2 A100 80GB GPUs on a Databricks cluster. After changing device_map to 'auto', I still get this error: "You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode. In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism. Therefore you should not specify that you are under any distributed regime in your accelerate config." Any suggestions why this is happening?
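As far as we can tell, this error is raised when a k-bit model split across several GPUs (`device_map="auto"`) is trained inside a distributed launch (accelerate launch, torchrun), which sets environment variables such as `WORLD_SIZE` and `LOCAL_RANK`. Two setups avoid the conflict: start the script with plain `python` and keep `device_map="auto"` (the naive pipeline parallelism the message mentions), or, under a launcher, give each process a full copy of the model on its own GPU. The sketch below picks between the two from the environment; `choose_device_map` is a hypothetical helper. For a 70B model, a full 4-bit copy per GPU is roughly 35 GB of weights, so the second option is only plausible on 80 GB cards and much tighter with 8-bit.

```python
import os


def choose_device_map(env=None):
    """Pick a device_map for from_pretrained based on how the script was launched.

    Under a distributed launcher (WORLD_SIZE > 1), each process loads the whole
    model on its own GPU: {"": local_rank}. Otherwise a single process spreads
    the layers over all visible GPUs with "auto" (naive pipeline parallelism).
    """
    env = os.environ if env is None else env
    world_size = int(env.get("WORLD_SIZE", "1"))
    if world_size > 1:
        return {"": int(env.get("LOCAL_RANK", "0"))}
    return "auto"


print(choose_device_map({"WORLD_SIZE": "2", "LOCAL_RANK": "1"}))  # one full copy on GPU 1
print(choose_device_map({}))  # plain `python script.py`
```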

@swaroop11
Hi @OneCodeToRuleThemAll, by "results are not good" I mean that the model does not answer from my training data: if I ask a test question (similar to one of the training questions), the model generates a completely different answer (I think it is answering from its old weights, i.e. what it learned before I fine-tuned it). The answers are not related to the data it was fine-tuned on.
The EOS issue is fixed for me with the help of the above-mentioned suggestion. Thanks for that.

@NPap0
NPap0 commented Oct 5, 2023

@swaroop11
Thanks for your response!
Do you have any theories as to why training did not perform as expected? Maybe too little data, too few epochs, or any other of the gazillion training parameters? Hahaha
Was the dataset clean?
I'm just trying to see if people have made this work for their use case.

I've encountered the 0 loss issue, and when trying to fix that I ran into predictions that never stop until the maximum generation length (max_new_tokens) is reached.

@sanipanwala
Hello,

Is anyone facing the issue below when running this script on multiple GPUs? I have set the local_rank parameter to 2.

operator(): block: [36,0,0], thread: [26,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.

Thanks.
