This document explains how we fine-tune the NousResearch/Llama-2-7b-chat-hf model on a financial tweet sentiment dataset using the QLoRA method. Training uses Hugging Face Transformers, PEFT (LoRA), and bitsandbytes for 4-bit quantization.
```python
model_name = "NousResearch/Llama-2-7b-chat-hf"
dataset_name = "BloomTech/finance-tweet-sentiment-llama2-1k"
new_model = "llama2-finance-tweet-sentiment-finetune"
```
- `model_name`: the pre-trained LLaMA-2 7B chat model we fine-tune.
- `dataset_name`: finance-related tweets labeled for sentiment analysis.
- `new_model`: the name under which the fine-tuned model will be saved.
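If you want a quick look at the data before training, here is a minimal sketch (assuming the dataset exposes a `text` field, which the `SFTTrainer` call later in this document relies on):

```python
from datasets import load_dataset

# Inspect one example; its "text" field is what SFTTrainer will
# consume later via dataset_text_field="text".
sample = load_dataset(dataset_name, split="train")[0]
print(sample["text"][:300])
```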
```python
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1
```
- LoRA (Low-Rank Adaptation) trains only a small set of parameters (see the sketch below):
  - `lora_r`: rank of the low-rank update matrices.
  - `lora_alpha`: scaling factor for the updates.
  - `lora_dropout`: regularization to prevent overfitting.
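To make the mechanics concrete, here is a toy sketch in plain PyTorch (not the PEFT internals) of how a LoRA update modifies a frozen weight matrix; the dimensions are illustrative:

```python
import torch

d, r, alpha = 4096, 64, 16     # hidden size, lora_r, lora_alpha
W = torch.randn(d, d)          # frozen pre-trained weight
A = torch.randn(r, d) * 0.01   # trainable low-rank factor
B = torch.zeros(d, r)          # trainable, initialized to zero

# Effective weight during training: only A and B receive gradients,
# and the update is scaled by alpha / r.
W_eff = W + (alpha / r) * (B @ A)

# Trainable params: 2*d*r for the adapter vs d*d for the full matrix.
print(2 * d * r, "trainable vs", d * d, "full")
```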
```python
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
```
- Enables 4-bit quantized model loading via `bitsandbytes`, letting large models fit on smaller GPUs.
- `nf4` is a more accurate quantization format than `fp4`.
- `float16` is used as the compute dtype.
- Nested (double) quantization is turned off for simplicity.
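Rough back-of-the-envelope numbers show why 4-bit loading matters for a 7B-parameter model (a sketch; real footprints vary with layer layout, quantization overhead, and activations):

```python
params = 7e9  # ~7B parameters

# Approximate weight storage only (no activations or optimizer state):
fp16_gb = params * 2 / 1024**3    # 16-bit: 2 bytes per parameter
nf4_gb = params * 0.5 / 1024**3   # 4-bit: 0.5 bytes per parameter

print(f"fp16: ~{fp16_gb:.1f} GB, nf4: ~{nf4_gb:.1f} GB")
```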
```python
output_dir = "./results"
num_train_epochs = 1
per_device_train_batch_size = 4
gradient_accumulation_steps = 1
learning_rate = 2e-4
optim = "paged_adamw_32bit"
lr_scheduler_type = "cosine"
warmup_ratio = 0.03
group_by_length = True
logging_steps = 25
```
- Basic setup for supervised fine-tuning.
- Gradient accumulation (here `gradient_accumulation_steps=1`, i.e. effectively off) and gradient checkpointing can be used to reduce memory usage; the effective batch size is sketched below.
- The cosine LR schedule with warmup avoids large jumps at the start of training.
- `paged_adamw_32bit` is a memory-efficient optimizer.
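One consequence of these settings worth spelling out: the effective batch size is the per-device batch size times the accumulation steps times the number of GPUs. A sketch, assuming the single-GPU setup implied by `device_map={"": 0}` below:

```python
num_gpus = 1  # assumption: single GPU, matching device_map={"": 0}
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch)  # 4 * 1 * 1 = 4
```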
```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Load dataset
dataset = load_dataset(dataset_name, split="train")

# Quantization config
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Load model in 4-bit onto GPU 0
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},
)
model.config.use_cache = False   # KV cache is incompatible with gradient checkpointing
model.config.pretraining_tp = 1  # disable tensor-parallel forward used at pretraining

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # LLaMA has no pad token by default
tokenizer.padding_side = "right"           # avoids overflow issues with fp16 training
```
- Loads the model in quantized 4-bit format (see the footprint check below).
- The tokenizer is adapted for LLaMA: the pad token is set to EOS and padding is right-sided.
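As a quick sanity check, you can confirm the 4-bit load and the tokenizer setup. `get_memory_footprint()` is a standard Transformers model method that reports parameter/buffer memory in bytes:

```python
# Report the loaded model's memory footprint in GB.
print(f"{model.get_memory_footprint() / 1024**3:.2f} GB")

# Verify padding configuration (for LLaMA-2 the EOS token is "</s>").
print(tokenizer.pad_token, tokenizer.padding_side)
```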
```python
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)
```
- Applies LoRA adapters to the attention layers.
- Only these adapters are trained; the base model weights stay frozen (a sketch of inspecting the trainable fraction follows).
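`SFTTrainer` applies `peft_config` internally, so you would not normally wrap the model yourself. But if you want to see how small the trainable fraction is, a sketch using PEFT's standard helpers:

```python
from peft import get_peft_model

# For inspection only: wraps the quantized base model with LoRA
# adapters and reports how few parameters are actually trainable.
inspect_model = get_peft_model(model, peft_config)
inspect_model.print_trainable_parameters()
```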
```python
from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=0,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,  # -1 means the epoch count drives training length
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="tensorboard",
)
```
```python
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)
trainer.train()
```
- `SFTTrainer` wraps everything into a single training loop.
- It automatically tokenizes the data, applies the LoRA adapters, and logs metrics.
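After training completes, you can persist just the LoRA adapter weights under `new_model`, the name defined at the top, using the standard PEFT/Transformers save path:

```python
# Save only the trained adapter weights (small on disk compared to
# the full model), plus the tokenizer, under the new_model directory.
trainer.model.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)
```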
With this setup, you're efficiently fine-tuning a 7B-parameter LLM on a single GPU with limited memory using:

- QLoRA (Quantized LoRA): a 4-bit quantized base model + low-rank updates
- LoRA adapters: train just the task-specific parts
- A memory-efficient paged optimizer (`paged_adamw_32bit`)
- Financial sentiment tweets: a real-world, targeted task