{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "1055c3f3", | |
"metadata": {}, | |
"source": [ | |
"### LoRA Fine-Tuning with MLX LM\n", | |
"\n", | |
"In this notebook, we'll walk through how to [LoRA fine-tune](https://arxiv.org/abs/2106.09685) an LLM with MLX LM. We'll use the [HellaSwag](https://rowanzellers.com/hellaswag/) dataset for common sense reasoning as an example. An outline:\n", | |
"\n", | |
"1. Download the dataset and prepare it in the right format for MLX LM.\n", | |
"2. Set up and run LoRA training. We'll show how to capture the training logs and plot some statistics to visualize the performance.\n", | |
"3. Evaluate on the test set. We'll compute the final question-answer accuracy of the fine-tuned model.\n", | |
"4. Fuse the resulting adapters into the base model and upload to Hugging Face.\n", | |
"5. Discuss tips for debugging accuracy and efficiency." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "21397627", | |
"metadata": {}, | |
"source": [ | |
"### Install dependencies" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "664272fb", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"!pip install mlx-lm\n", | |
"!pip install matplotlib" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "dd27c693", | |
"metadata": {}, | |
"source": [ | |
"### Preprocess Data\n", | |
"We'll start by downloading an already pre-processed version of the HellaSwag dataset from [LLM-Adapters](https://github.com/AGI-Edgerunners/LLM-Adapters)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "61698208", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"HellaSwag stats: 39905 training examples and 10042 test examples.\n", | |
"An example:\n", | |
"\n", | |
"{\n", | |
" \"instruction\": \"Please choose the correct ending to complete the given sentence: Removing ice from car: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then\\n\\nEnding1: , the man adds wax to the windshield and cuts it. Ending2: , a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled. Ending3: , the man puts on a christmas coat, knitted with netting. Ending4: , the man continues removing the snow on his car.\\n\\nAnswer format: ending1/ending2/ending3/ending4\",\n", | |
" \"input\": \"\",\n", | |
" \"output\": \"the correct answer is ending4\",\n", | |
" \"answer\": \"ending4\"\n", | |
"}\n" | |
] | |
} | |
], | |
"source": [ | |
"import json\n", | |
"import numpy as np\n", | |
"from pathlib import Path\n", | |
"from urllib import request\n", | |
"\n", | |
"save_dir = \"/tmp/hellaswag\"\n", | |
"\n", | |
"def download_and_save(save_dir):\n", | |
"    base_url = \"https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/main/dataset/hellaswag/\"\n", | |
"    save_dir = Path(save_dir)\n", | |
"    save_dir.mkdir(parents=True, exist_ok=True)\n", | |
"    for name in [\"train.json\", \"test.json\"]:\n", | |
"        out_file = save_dir / name\n", | |
"        if not out_file.exists():\n", | |
"            request.urlretrieve(base_url + name, out_file)\n", | |
"\n", | |
"def load_json(dataset):\n", | |
"    download_and_save(save_dir)\n", | |
"    with open(f\"{save_dir}/{dataset}.json\", \"r\") as fid:\n", | |
"        return json.load(fid)\n", | |
"\n", | |
"train_set, test_set = load_json(\"train\"), load_json(\"test\")\n", | |
"print(f\"HellaSwag stats: {len(train_set)} training examples and {len(test_set)} test examples.\")\n", | |
"print(\"An example:\\n\")\n", | |
"print(json.dumps(train_set[0], indent=4))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9a514d79", | |
"metadata": {}, | |
"source": [ | |
"Next, let's split the training set into a training and a validation set. We'll pull out a randomly chosen 10% for validation." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "9b607237", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Seed for reproducibility\n", | |
"np.random.seed(43)\n", | |
"perm = np.random.permutation(len(train_set))\n", | |
"valid_size = int(0.1 * len(train_set))\n", | |
"valid_set = [train_set[i] for i in perm[:valid_size]]\n", | |
"train_set = [train_set[i] for i in perm[valid_size:]]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "35c38c4e", | |
"metadata": {}, | |
"source": [ | |
"Finally, put the data splits in the MLX LM training format. The format simply expects the data to be in a container which supports random access to the individual examples (e.g. a Python `list`):\n", | |
"```\n", | |
"[\"An example for the model.\", \"Another example for the model.\", ...]\n", | |
"```\n", | |
"For more details, see the [documentation on supported formats](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#Data)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "ea738f2b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def preprocess(dataset):\n", | |
"    return [t[\"instruction\"] + \"\\n\" + t[\"output\"] for t in dataset]\n", | |
"\n", | |
"train_set, valid_set, test_set = map(preprocess, (train_set, valid_set, test_set))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b259eb69", | |
"metadata": {}, | |
"source": [ | |
"### Fine-Tune\n", | |
"\n", | |
"For fine-tuning, we'll use Microsoft's [Phi-3 mini](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). At 3.8 billion parameters, Phi-3 mini is a high-quality model that is also fast to fine-tune on most Apple silicon machines. Also, it has a [permissive MIT License](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE).\n", | |
"\n", | |
"First, import all the packages and functions we need." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "c3ff309a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"import mlx.core as mx\n", | |
"import mlx.optimizers as optim\n", | |
"from mlx.utils import tree_flatten\n", | |
"from mlx_lm import load, generate\n", | |
"from mlx_lm.tuner import train, evaluate, TrainingArgs\n", | |
"from mlx_lm.tuner import linear_to_lora_layers\n", | |
"import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "87628d24", | |
"metadata": {}, | |
"source": [ | |
"Next, set up the LoRA parameters and make the training arguments. See the [training argument class](https://github.com/ml-explore/mlx-examples/blob/81318ad4a8b2ca5fd1431a42db2b0244d16be851/llms/mlx_lm/tuner/trainer.py#L31-L63) for a more detailed list of training parameters.\n", | |
"\n", | |
"Recall the LoRA update is $W^\\top \\mathbf{x} + c \\cdot \\mathbf{a} \\mathbf{b}^\\top \\mathbf{x}$ where $\\mathbf{a}$ has shape `(D, rank)`.\n", | |
"\n", | |
"With that in mind, the LoRA parameters to attend to are:\n", | |
"- `lora_layers`: The number of Transformer blocks from the top of the model to adapt.\n", | |
"- `rank`: The rank of the low-rank adapters. A larger rank implies more adapter parameters per linear layer.\n", | |
"- `scale`: This is the constant $c$ that scales the low-rank update." | |
] | |
}, | |
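{ | |
"cell_type": "markdown", | |
"id": "lora-update-sketch-md", | |
"metadata": {}, | |
"source": [
"As a rough illustration of the update above, here is a minimal NumPy sketch (not the MLX LM implementation). The shapes are assumptions chosen so the formula is well-defined: `W` is `(D_in, D)`, `a` is `(D, rank)`, and `b` is `(D_in, rank)`. The sizes and the names `d_in` and `d_out` are hypothetical. The sketch also checks that folding the update into a single matrix gives the same output, which is the idea behind fusing adapters into the base weights." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "lora-update-sketch-code", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"rng = np.random.default_rng(0)\n", | |
"# Hypothetical sizes; c plays the role of the `scale` parameter\n", | |
"d_in, d_out, rank, c = 16, 32, 4, 20.0\n", | |
"\n", | |
"W = rng.normal(size=(d_in, d_out))  # frozen base weight\n", | |
"a = rng.normal(size=(d_out, rank))  # low-rank factor with shape (D, rank)\n", | |
"b = rng.normal(size=(d_in, rank))   # second low-rank factor\n", | |
"x = rng.normal(size=(d_in,))        # a single input vector\n", | |
"\n", | |
"# LoRA update: W^T x + c * a b^T x\n", | |
"y_lora = W.T @ x + c * (a @ (b.T @ x))\n", | |
"\n", | |
"# Same computation with the low-rank update folded into one matrix\n", | |
"W_fused = W + c * (b @ a.T)\n", | |
"y_fused = W_fused.T @ x\n", | |
"\n", | |
"print(\"Outputs match:\", np.allclose(y_lora, y_fused))" | |
] | |
}, | |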
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "f0851dc8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Make a directory to save the adapter config and weights\n", | |
"adapter_path = Path(\"adapters\")\n", | |
"adapter_path.mkdir(parents=True, exist_ok=True)\n", | |
"\n", | |
"lora_config = {\n", | |
"    \"lora_layers\": 8,\n", | |
"    \"lora_parameters\": {\n", | |
"        \"rank\": 8,\n", | |
"        \"scale\": 20.0,\n", | |
"        \"dropout\": 0.0,\n", | |
"}}\n", | |
"\n", | |
"# Save the LoRA config to the adapter path\n", | |
"with open(adapter_path / \"adapter_config.json\", \"w\") as fid:\n", | |
"    json.dump(lora_config, fid, indent=4)\n", | |
"\n", | |
"training_args = TrainingArgs(\n", | |
"    adapter_file=adapter_path / \"adapters.safetensors\",\n", | |
"    iters=200,\n", | |
"    steps_per_eval=50\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "56fefd19", | |
"metadata": {}, | |
"source": [ | |
"Next, load the Phi-3 mini model. Note this may take a few minutes to download from Hugging Face if you haven't downloaded it before." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "fb0b16f2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model_path = \"microsoft/Phi-3-mini-4k-instruct\"\n", | |
"model, tokenizer = load(model_path)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6609c92a", | |
"metadata": {}, | |
"source": [ | |
"After loading the model, freeze its parameters so we don't train them. Then convert linear layers to LoRA layers using the MLX LM utility `linear_to_lora_layers`. The adapters in the `LoRA` layers are not frozen, so they will be included in the model's `trainable_parameters`. Check out the [LoRA layer implementation](https://github.com/ml-explore/mlx-examples/blob/81318ad4a8b2ca5fd1431a42db2b0244d16be851/llms/mlx_lm/tuner/lora.py#L72-L104) to see how it all works.\n", | |
"\n", | |
"By default, MLX LM only adapts the query, key, and value projection matrices for Phi-3. You can specify the layers to adapt by setting `lora_parameters[\"keys\"]` to a list of layer names. In this case it defaults to `[\"attn.qkv_proj\"]`. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "50e1ab3a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Number of trainable parameters: 786432\n" | |
] | |
} | |
], | |
"source": [ | |
"# Freeze the base model\n", | |
"model.freeze()\n", | |
"\n", | |
"# Convert linear layers to lora layers\n", | |
"linear_to_lora_layers(model, lora_config[\"lora_layers\"], lora_config[\"lora_parameters\"])\n", | |
"\n", | |
"num_train_params = (\n", | |
"    sum(v.size for _, v in tree_flatten(model.trainable_parameters()))\n", | |
")\n", | |
"print(f\"Number of trainable parameters: {num_train_params}\")" | |
] | |
}, | |
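{ | |
"cell_type": "markdown", | |
"id": "param-count-check-md", | |
"metadata": {}, | |
"source": [ | |
"As a sanity check on the count above, the arithmetic can be reproduced by hand. Each adapted layer adds two low-rank matrices, one of shape `(input_dims, rank)` and one of shape `(rank, output_dims)`. This sketch assumes Phi-3 mini has a hidden size of 3072 and that the fused `attn.qkv_proj` layer maps 3072 inputs to 3 x 3072 = 9216 outputs. Both values are assumptions not stated in this notebook, so verify them against the model config." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "param-count-check-code", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Assumed Phi-3 mini dimensions (not taken from this notebook)\n", | |
"hidden_size = 3072\n", | |
"qkv_out = 3 * hidden_size\n", | |
"\n", | |
"rank = lora_config[\"lora_parameters\"][\"rank\"]\n", | |
"num_layers = lora_config[\"lora_layers\"]\n", | |
"\n", | |
"# Adapter parameters per layer: (hidden_size x rank) + (rank x qkv_out)\n", | |
"per_layer = rank * (hidden_size + qkv_out)\n", | |
"expected = num_layers * per_layer\n", | |
"print(f\"Expected {expected}, measured {num_train_params}\")" | |
] | |
}, | |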
{ | |
"cell_type": "markdown", | |
"id": "827d1590", | |
"metadata": {}, | |
"source": [ | |
"Now we're ready to put it all together and actually train the model. We'll use `Adam` for the optimizer, but you can specify any [optimizer](https://ml-explore.github.io/mlx/build/html/python/optimizers/common_optimizers.html) with any [scheduler](https://ml-explore.github.io/mlx/build/html/python/optimizers/schedulers.html). We also added a custom class to capture the training and validation loss to plot it later." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "984516d3", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Starting training..., iters: 200\n", | |
"Iter 1: Val loss 2.687, Val took 32.210s\n", | |
"Iter 10: Train loss 2.456, Learning Rate 1.000e-05, It/sec 0.511, Tokens/sec 583.542, Trained Tokens 11412, Peak mem 14.235 GB\n", | |
"Iter 20: Train loss 2.328, Learning Rate 1.000e-05, It/sec 0.708, Tokens/sec 605.538, Trained Tokens 19963, Peak mem 14.235 GB\n", | |
"Iter 30: Train loss 2.062, Learning Rate 1.000e-05, It/sec 0.621, Tokens/sec 625.616, Trained Tokens 30034, Peak mem 14.235 GB\n", | |
"Iter 40: Train loss 1.856, Learning Rate 1.000e-05, It/sec 0.577, Tokens/sec 613.944, Trained Tokens 40670, Peak mem 14.235 GB\n", | |
"Iter 50: Val loss 1.808, Val took 30.444s\n", | |
"Iter 50: Train loss 1.872, Learning Rate 1.000e-05, It/sec 4.770, Tokens/sec 5100.991, Trained Tokens 51365, Peak mem 14.235 GB\n", | |
"Iter 60: Train loss 1.904, Learning Rate 1.000e-05, It/sec 0.588, Tokens/sec 640.415, Trained Tokens 62257, Peak mem 14.235 GB\n", | |
"Iter 70: Train loss 1.737, Learning Rate 1.000e-05, It/sec 0.603, Tokens/sec 648.153, Trained Tokens 73010, Peak mem 14.235 GB\n", | |
"Iter 80: Train loss 1.810, Learning Rate 1.000e-05, It/sec 0.657, Tokens/sec 638.259, Trained Tokens 82730, Peak mem 14.235 GB\n", | |
"Iter 90: Train loss 1.652, Learning Rate 1.000e-05, It/sec 0.667, Tokens/sec 631.938, Trained Tokens 92203, Peak mem 14.235 GB\n", | |
"Iter 100: Val loss 1.756, Val took 28.710s\n", | |
"Iter 100: Train loss 1.811, Learning Rate 1.000e-05, It/sec 3.804, Tokens/sec 3831.672, Trained Tokens 102276, Peak mem 14.235 GB\n", | |
"Iter 100: Saved adapter weights to adapters/adapters.safetensors and adapters/0000100_adapters.safetensors.\n", | |
"Iter 110: Train loss 1.770, Learning Rate 1.000e-05, It/sec 0.678, Tokens/sec 610.968, Trained Tokens 111292, Peak mem 14.235 GB\n", | |
"Iter 120: Train loss 1.696, Learning Rate 1.000e-05, It/sec 0.734, Tokens/sec 631.704, Trained Tokens 119904, Peak mem 14.235 GB\n", | |
"Iter 130: Train loss 1.719, Learning Rate 1.000e-05, It/sec 0.729, Tokens/sec 636.505, Trained Tokens 128640, Peak mem 14.235 GB\n", | |
"Iter 140: Train loss 1.673, Learning Rate 1.000e-05, It/sec 0.660, Tokens/sec 653.183, Trained Tokens 138543, Peak mem 14.235 GB\n", | |
"Iter 150: Val loss 1.645, Val took 28.133s\n", | |
"Iter 150: Train loss 1.694, Learning Rate 1.000e-05, It/sec 5.227, Tokens/sec 5618.089, Trained Tokens 149291, Peak mem 14.235 GB\n", | |
"Iter 160: Train loss 1.691, Learning Rate 1.000e-05, It/sec 0.800, Tokens/sec 608.860, Trained Tokens 156906, Peak mem 14.235 GB\n", | |
"Iter 170: Train loss 1.653, Learning Rate 1.000e-05, It/sec 0.848, Tokens/sec 640.269, Trained Tokens 164458, Peak mem 14.235 GB\n", | |
"Iter 180: Train loss 1.638, Learning Rate 1.000e-05, It/sec 0.516, Tokens/sec 644.267, Trained Tokens 176950, Peak mem 14.235 GB\n", | |
"Iter 190: Train loss 1.579, Learning Rate 1.000e-05, It/sec 0.666, Tokens/sec 627.337, Trained Tokens 186373, Peak mem 14.235 GB\n", | |
"Iter 200: Val loss 1.758, Val took 31.519s\n", | |
"Iter 200: Train loss 1.658, Learning Rate 1.000e-05, It/sec 6.545, Tokens/sec 5317.109, Trained Tokens 194497, Peak mem 14.235 GB\n", | |
"Iter 200: Saved adapter weights to adapters/adapters.safetensors and adapters/0000200_adapters.safetensors.\n", | |
"Saved final adapter weights to adapters/adapters.safetensors.\n" | |
] | |
} | |
], | |
"source": [ | |
"# Put the model in training mode:\n", | |
"model.train()\n", | |
"\n", | |
"# Make the optimizer:\n", | |
"opt = optim.Adam(learning_rate=1e-5)\n", | |
"\n", | |
"# Make a class to record the training stats:\n", | |
"class Metrics:\n", | |
"    def __init__(self):\n", | |
"        # Keep the lists per instance rather than shared on the class\n", | |
"        self.train_losses = []\n", | |
"        self.val_losses = []\n", | |
"\n", | |
"    def on_train_loss_report(self, info):\n", | |
"        self.train_losses.append((info[\"iteration\"], info[\"train_loss\"]))\n", | |
"\n", | |
"    def on_val_loss_report(self, info):\n", | |
"        self.val_losses.append((info[\"iteration\"], info[\"val_loss\"]))\n", | |
"\n", | |
"metrics = Metrics()\n", | |
"\n", | |
"# Train model:\n", | |
"train(\n", | |
"    model=model,\n", | |
"    tokenizer=tokenizer,\n", | |
"    args=training_args,\n", | |
"    optimizer=opt,\n", | |
"    train_dataset=train_set,\n", | |
"    val_dataset=valid_set,\n", | |
"    training_callback=metrics,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b8d043b8", | |
"metadata": {}, | |
"source": [ | |
"The adapters are saved every 100 iterations along with the final adapters in `adapters.safetensors`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ac329358", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!ls adapters/" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "2b7e23ee", | |
"metadata": {}, | |
"source": [ | |
"Next, let's plot the training and validation losses to see how well the adapters fit the data." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "f1ffd638", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"train_its, train_losses = zip(*metrics.train_losses)\n", | |
"val_its, val_losses = zip(*metrics.val_losses)\n", | |
"plt.plot(train_its, train_losses, '-o')\n", | |
"plt.plot(val_its, val_losses, '-o')\n", | |
"plt.xlabel(\"Iteration\")\n", | |
"plt.ylabel(\"Loss\")\n", | |
"plt.legend(['Train', \"Valid\"]);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b28f216c", | |
"metadata": {}, | |
"source": [ | |
"### Evaluate\n", | |
"\n", | |
"The training and validation loss are only part of the story. For HellaSwag, we ultimately care about how good the model is at answering questions. To assess this, let's generate the actual `ending1`, `ending2`, `ending3`, or `ending4` responses with the fine-tuned model and measure the accuracy.\n", | |
"\n", | |
"First, let's split the last word off of each example in the test set to create a prompt without the answer." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "d96e4dcf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_set = [t.rsplit(\" \", maxsplit=1) for t in test_set]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8becd26a", | |
"metadata": {}, | |
"source": [ | |
"Next, we'll generate the response for each example in the test set and compare it to the ground-truth answer to measure the accuracy." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b396980a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Increase this number to use more test examples\n", | |
"num_test = 100\n", | |
"num_correct = 0\n", | |
"for prompt, answer in tqdm.tqdm(test_set[:num_test]):\n", | |
"    response = generate(model, tokenizer, prompt, max_tokens=2)\n", | |
"    num_correct += (response==answer)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "4cbc00b3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Approximate test accuracy 0.750\n" | |
] | |
} | |
], | |
"source": [ | |
"test_acc = num_correct / num_test\n", | |
"print(f\"Approximate test accuracy {test_acc:.3f}\")" | |
] | |
}, | |
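{ | |
"cell_type": "markdown", | |
"id": "accuracy-stderr-md", | |
"metadata": {}, | |
"source": [ | |
"Because the accuracy above is measured on only `num_test` examples, it carries some sampling noise. A rough way to gauge it is the binomial standard error, treating the test examples as independent draws (an assumption). With 75 of 100 examples correct, this gives a 95% interval of roughly plus or minus 0.085." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "accuracy-stderr-code", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import math\n", | |
"\n", | |
"# Binomial standard error of the accuracy estimate\n", | |
"std_err = math.sqrt(test_acc * (1 - test_acc) / num_test)\n", | |
"\n", | |
"# 1.96 standard errors gives an approximate 95% interval\n", | |
"print(f\"Test accuracy {test_acc:.3f} +/- {1.96 * std_err:.3f}\")" | |
] | |
}, | |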
{ | |
"cell_type": "markdown", | |
"id": "67fbba7f", | |
"metadata": {}, | |
"source": [ | |
"### Fuse Adapters\n", | |
"\n", | |
"Sometimes it's convenient to fuse the adapters into the base model to create a single adapted model. MLX LM has a fuse script just for that.\n", | |
"\n", | |
"The adapted weights are: $\\tilde{W} = W + c \\cdot \\mathbf{b} \\mathbf{a}^\\top$, so that $\\tilde{W}^\\top \\mathbf{x} = W^\\top \\mathbf{x} + c \\cdot \\mathbf{a} \\mathbf{b}^\\top \\mathbf{x}$ matches the LoRA update. Note, this process can be destructive if the two terms being added are in low precision and have very different magnitudes. Tuning the `scale` parameter, $c$, prior to fine-tuning can improve the model performance after fusion.\n", | |
"\n", | |
"To see more options for fusing the model, including how to upload to Hugging Face, [check the documentation](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#fuse)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"id": "37854c9b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!mlx_lm.fuse --model {model_path}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c349707e", | |
"metadata": {}, | |
"source": [ | |
"Once the adapters are fused, we can rerun the evaluation using the fused model to make sure it worked. By default the fused model will be saved to `lora_fused_model`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"id": "c1c45e3a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Approximate test accuracy 0.750\n" | |
] | |
} | |
], | |
"source": [ | |
"model, tokenizer = load(\"lora_fused_model\")\n", | |
"num_correct = 0\n", | |
"for prompt, answer in tqdm.tqdm(test_set[:num_test]):\n", | |
"    response = generate(model, tokenizer, prompt, max_tokens=2)\n", | |
"    num_correct += (response==answer)\n", | |
"test_acc = num_correct / num_test\n", | |
"print(f\"Approximate test accuracy {test_acc:.3f}\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d0dc7f4c", | |
"metadata": {}, | |
"source": [ | |
"### Troubleshooting\n", | |
"\n", | |
"#### Results\n", | |
"\n", | |
"To figure out why your LoRA adapters are not working well, it's critical to plot both the training loss and validation loss over the duration of fine-tuning. There are really only two cases to consider: underfitting or overfitting. You can figure out which regime you are in based on the above plot.\n", | |
"\n", | |
"**Underfitting**: The training loss is not low enough and the validation loss closely matches the training loss. You could also measure the accuracy on the training set itself for question-answering style tasks like HellaSwag. If you are in this regime you have a few options to improve the results:\n", | |
"\n", | |
"- Use more adapters. Increase `lora_layers` or adapt more of the linear layers within a given block by setting `lora_parameters[\"keys\"]`.\n", | |
"- Use a higher rank. A higher rank means more parameters per adapter.\n", | |
"- If you are using dropout, decrease the dropout rate or turn it off entirely.\n", | |
"- Sometimes, underfitting issues are really optimization issues. In these cases it can be helpful to tune the learning rate or learning rate schedule.\n", | |
"- If none of the above works, try a bigger model. For example, try Phi-3 medium instead of Phi-3 mini.\n", | |
"\n", | |
"**Overfitting**: The training loss keeps going down but the validation loss stops going down and even starts to go up. If you are in this regime you also have a few options:\n", | |
"\n", | |
"- The best thing to do is to use more training data if you have it.\n", | |
"- Contrary to the underfitting regime, decreasing the capacity of the model can help. For example, use fewer adapters, a lower LoRA rank, or a smaller model size.\n", | |
"- If you are not using dropout, use it.\n", | |
"\n", | |
"If you find your adapters work well pre-fusion but stop working post-fusion, try tuning the `scale` parameter, $c$, prior to fine-tuning. Typically the adapters have a smaller magnitude than the weights, so using a larger scale helps.\n", | |
"\n", | |
"#### Memory Use\n", | |
"\n", | |
"Fine-tuning a large LM with LoRA requires a machine with a decent amount of memory. Here are some tips to reduce memory use should you need to do so. \n", | |
"\n", | |
"- Try quantization (QLoRA). You can use QLoRA by generating a quantized model with `mlx_lm.convert` and the `-q` flag or by using an already quantized model from Hugging Face.\n", | |
"\n", | |
"- Try using a smaller batch size. You can set the `batch_size` parameter in the `TrainingArgs` or pass `--batch-size` if you are using the CLI. The default is 4, so setting this to 2 or 1 will reduce memory consumption. Note, this may slow things down a little.\n", | |
"\n", | |
"- Reduce the number of layers to fine-tune by setting `lora_layers` to a smaller value or passing `--lora-layers` if you are using the CLI. The default is `16`, so you can try `8` or `4`. This reduces the amount of memory needed for back propagation. It may also reduce the quality of the fine-tuned model and you may need to compensate with a larger `rank`.\n", | |
"\n", | |
"- Longer examples require more memory. If it makes sense for your data, one thing you can do is break your examples into smaller sequences when making the `train`, `valid`, and `test` data sets.\n", | |
"\n", | |
"- Gradient checkpointing lets you trade off memory use (less) for computation (more) by recomputing instead of storing intermediate values needed by the backward pass. You can use gradient checkpointing by passing `grad_checkpoint=True` to the `TrainingArgs` or the `--grad-checkpoint` flag if using the CLI. Gradient checkpointing will be more helpful for larger batch sizes or sequence lengths with smaller or quantized models.\n", | |
"\n", | |
"### Next Steps\n", | |
"\n", | |
"- To learn more about MLX check out the [GitHub repo](http://github.com/ml-explore/mlx) and [documentation](https://ml-explore.github.io/mlx/).\n", | |
"- For more on MLX LM check out the [MLX LM documentation](https://github.com/ml-explore/mlx-examples/tree/main/llms#readme).\n", | |
"- Check out the other [MLX Examples](https://github.com/ml-explore/mlx-examples/tree/main). These are great as a learning resource or to use as a starting point for a new project.\n", | |
"- We also have an example of [LoRA fine-tuning in MLX Swift](https://github.com/ml-explore/mlx-swift-examples/tree/main/Applications/LoRATrainingExample)." | |
] | |
} | |
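, | |
{ | |
"cell_type": "markdown", | |
"id": "troubleshooting-sketch-md", | |
"metadata": {}, | |
"source": [ | |
"Returning to the troubleshooting tips above, a couple of the suggested knobs can be sketched in code. This is an illustrative configuration rather than a recommendation: the cosine schedule from `mlx.optimizers`, the value `batch_size=2`, and the names `lr_schedule`, `scheduled_opt`, and `low_mem_args` are assumptions to adapt to your own setup." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "troubleshooting-sketch-code", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pathlib import Path\n", | |
"\n", | |
"import mlx.optimizers as optim\n", | |
"from mlx_lm.tuner import TrainingArgs\n", | |
"\n", | |
"# Optimization: decay the learning rate with a cosine schedule over 200 iterations\n", | |
"lr_schedule = optim.cosine_decay(1e-5, 200)\n", | |
"scheduled_opt = optim.Adam(learning_rate=lr_schedule)\n", | |
"\n", | |
"# Memory: use a smaller batch and turn on gradient checkpointing\n", | |
"low_mem_args = TrainingArgs(\n", | |
"    adapter_file=Path(\"adapters\") / \"adapters.safetensors\",\n", | |
"    iters=200,\n", | |
"    steps_per_eval=50,\n", | |
"    batch_size=2,\n", | |
"    grad_checkpoint=True,\n", | |
")" | |
] | |
} | |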
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.17" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |