@EvilFreelancer
Last active April 30, 2024 11:11

TorchTune: LoRA multi-dataset training of Gemma 2B model on a single device

TorchTune / Gemma 2B / LoRA / Single Device

Fine-tune the Gemma 2B model with LoRA on a single device using multiple datasets.

Prepare environment

git clone https://github.com/EvilFreelancer/torchtune.git
cd torchtune
git switch feat-concatenate-datasets
python3 -m venv venv
source venv/bin/activate
pip install .
pip install wandb
tune download google/gemma-2b --output-dir=./gemma --ignore-patterns=''
tune cp gemma/2B_lora ./gemma_2B_lora_single_device.yaml
sed -r 's#/tmp/#./#g' -i ./gemma_2B_lora_single_device.yaml
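
The sed call rewrites every /tmp/ path in the copied config so that it points into the current directory, where the model was just downloaded (./gemma). A rough Python equivalent is sketched below; the helper names and the two sample config lines are illustrative assumptions, not copied from the shipped file.

```python
def relocate_tmp_paths(config_text: str) -> str:
    """Mirror `sed -r 's#/tmp/#./#g'`: replace every "/tmp/" with "./"."""
    return config_text.replace("/tmp/", "./")


def leftover_tmp_lines(config_text: str) -> list:
    """Return the lines that still mention /tmp/, as a quick sanity check."""
    return [line for line in config_text.splitlines() if "/tmp/" in line]


# Illustrative config lines (hypothetical sample, not the real file).
sample = "output_dir: /tmp/alpaca-gemma-lora\ncheckpoint_dir: /tmp/gemma/\n"
fixed = relocate_tmp_paths(sample)
print(fixed)
print(leftover_tmp_lines(fixed))  # [] once every path has been rewritten
```

Running leftover_tmp_lines on the real file contents after the sed step is a cheap way to confirm that nothing still points at /tmp/.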

Next, replace the metric_logger section:

metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  project: gemma_2B_lora_single_device

Next, replace the dataset section:

dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.instruct_dataset
    source: samsum
    template: SummarizeTemplate
    column_map: {"output": "summary"}
    split: train
    train_on_input: False
seed: null
shuffle: True

Now we can start training:

tune run lora_finetune_single_device --config ./gemma_2B_lora_single_device.yaml

Full gemma_2B_lora_single_device.yaml after both replacements:

# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: ./gemma/tokenizer.model

# Dataset
dataset:
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  - _component_: torchtune.datasets.instruct_dataset
    source: samsum
    template: SummarizeTemplate
    column_map: {"output": "summary"}
    split: train
    train_on_input: False
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.gemma.lora_gemma_2b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: True
  lora_rank: 64
  lora_alpha: 16

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: ./gemma/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ./gemma
  model_type: GEMMA
resume_from_checkpoint: False

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5

lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Fine-tuning arguments
batch_size: 2
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 4
compile: False

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  project: gemma_2B_lora_single_device
output_dir: ./alpaca-gemma-lora
log_every_n_steps: 1
log_peak_memory_stats: False

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.utils.profiler
  enabled: False
  output_dir: ./alpaca-gemma-finetune/torchtune_perf_tracing.json
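
A few numbers follow directly from the values in this config. The sketch below works them out: the effective batch size per optimizer step, the LoRA scaling factor (alpha divided by rank), and a rough count of optimizer steps per epoch. The step count assumes the last partial batch is kept and that leftover micro-batches which do not fill a full accumulation window are dropped; the 1000-sample figure is a hypothetical placeholder, not the real size of either dataset.

```python
import math


def effective_batch_size(batch_size: int, gradient_accumulation_steps: int) -> int:
    """Samples contributing to each optimizer step."""
    return batch_size * gradient_accumulation_steps


def lora_scaling(lora_alpha: float, lora_rank: int) -> float:
    """Factor applied to the low-rank update: alpha / rank."""
    return lora_alpha / lora_rank


def optimizer_steps_per_epoch(num_samples: int, batch_size: int,
                              gradient_accumulation_steps: int) -> int:
    """Rough step count under the assumptions stated in the lead-in."""
    micro_batches = math.ceil(num_samples / batch_size)
    return micro_batches // gradient_accumulation_steps


print(effective_batch_size(2, 4))             # 8
print(lora_scaling(16, 64))                   # 0.25
print(optimizer_steps_per_epoch(1000, 2, 4))  # 125 for a hypothetical 1000 samples
```

With num_warmup_steps set to 100, comparing that value against the real step count for the combined datasets shows how much of the single epoch is spent in warmup.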