Training the Gemma 2B model on a single device using multiple datasets.
# Clone the torchtune fork and switch to the branch that accepts a list of
# datasets under `dataset:` in the config.
git clone https://github.com/EvilFreelancer/torchtune.git
cd torchtune
git switch feat-concatenate-datasets

# Install the fork into an isolated virtual environment, plus wandb for the
# Weights & Biases metric logger configured below.
python3 -m venv venv
source venv/bin/activate
pip install .
pip install wandb

# Download the Gemma 2B checkpoint into ./gemma.
# The empty --ignore-patterns value presumably disables the CLI's default file
# filter so every file in the repo is fetched — confirm with `tune download --help`.
# NOTE(review): google/gemma-2b is gated on the Hugging Face Hub; you may need to
# accept its license and pass --hf-token (or log in) before this succeeds.
tune download google/gemma-2b --output-dir=./gemma --ignore-patterns=''

# Copy the built-in Gemma 2B LoRA config into the working directory.
# NOTE(review): the target file name and the lora_finetune_single_device recipe
# used for training suggest gemma/2B_lora_single_device may be the intended
# source config, if this branch ships it — check with `tune ls`.
tune cp gemma/2B_lora ./gemma_2B_lora_single_device.yaml

# Replace every "/tmp/" in the copied config with "./" (GNU sed: -r enables
# extended regexes, -i edits in place; BSD/macOS sed needs `-i ''` instead).
sed -r 's#/tmp/#./#g' -i ./gemma_2B_lora_single_device.yaml

# Next, replace the metric_logger section of the config with:
# Log training metrics to Weights & Biases under the given project name
# (requires the wandb package).
# NOTE(review): `torchtune.utils.metric_logging` matches the torchtune version on
# this branch; newer torchtune releases moved it to `torchtune.training.metric_logging`.
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  project: gemma_2B_lora_single_device
# Next, replace the dataset section with:
# Train on several datasets at once: `dataset` is a list of dataset configs,
# which the feat-concatenate-datasets branch presumably combines into one
# training set — confirm against the branch's recipe code.
dataset:
  # Alpaca-GPT4 instructions; train_on_input: True also trains on the prompt.
  - _component_: torchtune.datasets.instruct_dataset
    source: vicgalle/alpaca-gpt4
    template: AlpacaInstructTemplate
    split: train
    train_on_input: True
  # SAMSum dialogue summarization; column_map makes the template's "output"
  # field read from the dataset's "summary" column, and train_on_input: False
  # keeps the prompt out of training.
  - _component_: torchtune.datasets.instruct_dataset
    source: samsum
    template: SummarizeTemplate
    column_map: {"output": "summary"}
    split: train
    train_on_input: False
# seed and shuffle are top-level recipe options, not part of the dataset list.
seed: null
shuffle: True
# Now we can start training:
# Launch single-device LoRA fine-tuning using the edited config file.
tune run lora_finetune_single_device \
  --config ./gemma_2B_lora_single_device.yaml