LoRA Merging Investigation
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12f39541-04a5-4e91-84c2-e798cdb43b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from peft import PeftModel\n",
    "import torch\n",
    "import lovely_tensors as lt\n",
    "lt.monkey_patch()\n",
    "\n",
    "import os\n",
    "\n",
    "def merge_lora(base_model_path, lora_path, output_dir):\n",
    "    device_arg = {'device_map': 'auto'}\n",
    "\n",
    "    base_model = AutoModelForCausalLM.from_pretrained(\n",
    "        base_model_path,\n",
    "        return_dict=True,\n",
    "        torch_dtype=torch.float16,\n",
    "        **device_arg\n",
    "    )\n",
    "\n",
    "    print(f\"Loading PEFT: {lora_path}\")\n",
    "    model = PeftModel.from_pretrained(base_model, lora_path, **device_arg)\n",
    "    print(\"Running merge_adapter\")\n",
    "    model.merge_adapter()  # alternative: model = model.merge_and_unload()\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(base_model_path)\n",
    "\n",
    "    # Saving is disabled during the investigation; re-enable to persist the merge.\n",
    "    # model.save_pretrained(output_dir, safe_serialization=True)\n",
    "    # tokenizer.save_pretrained(output_dir)\n",
    "    print(f\"Merge complete (saving to {output_dir} is disabled)\")\n",
    "    return model"
   ]
  },
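  {
   "cell_type": "code",
   "execution_count": null,
   "id": "merge-unload-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A hedged sketch of the standard PEFT merge path, for contrast with\n",
    "# merge_adapter() above: merge_and_unload() folds the adapter into the base\n",
    "# weights and returns the plain transformers model (the PeftModel wrapper and\n",
    "# its lora_A/lora_B modules are discarded). Left commented out because it\n",
    "# loads a 13B model; it reuses the base/lora/out variables defined later.\n",
    "#\n",
    "# merged = merge_lora(base, lora, out)\n",
    "# plain_model = merged.merge_and_unload()\n",
    "# plain_model.save_pretrained(out, safe_serialization=True)"
   ]
  },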
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea701e52-45b3-4335-9886-37b4a57f78dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Partial loading experiment: inspect the raw adapter tensors straight off disk.\n",
    "from safetensors.torch import load_file as safe_load_file\n",
    "import torch\n",
    "# Note: safe_load_file expects a file path (e.g. .../adapter_model.safetensors), not a directory:\n",
    "# adapters_weights = safe_load_file(\"loras/LIMARP-Llama2-LoRA-adapter-13B\", device=\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "adapters_weights = torch.load(\"loras/LIMARP-Llama2-LoRA-adapter-13B/adapter_model.bin\", map_location=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd3da770-aca1-482a-ab56-e0ee428d8f87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(adapters_weights)\n",
    "for k, v in adapters_weights.items():\n",
    "    print(k, v)"
   ]
  },
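  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adapter-shape-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A small sketch (added for clarity, assuming the standard PEFT key layout\n",
    "# '...<module>.lora_A.weight' / '...<module>.lora_B.weight'): pair each module's\n",
    "# A and B tensors and check the rank dimension. lora_A is (r, in_features) and\n",
    "# lora_B is (out_features, r), so B @ A has the shape of the base weight.\n",
    "prefixes = sorted({k.rsplit('.lora_', 1)[0] for k in adapters_weights if '.lora_' in k})\n",
    "for prefix in prefixes[:4]:\n",
    "    a = adapters_weights[prefix + '.lora_A.weight']\n",
    "    b = adapters_weights[prefix + '.lora_B.weight']\n",
    "    print(prefix, 'A:', tuple(a.shape), 'B:', tuple(b.shape), 'delta:', tuple((b @ a).shape))"
   ]
  },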
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0388a4c4-71b6-4eed-ad11-1fa96b1be0a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "base = \"models/Nous-Hermes-Llama2-13b-Storyteller\"\n",
    "lora = \"loras/LIMARP-Llama2-LoRA-adapter-13B\"\n",
    "out = \"merged-lima\"\n",
    "# model = merge_lora(base, lora, out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75ac8559-fa58-461f-8560-63c6178c0761",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot the PEFT-merged model (requires `model = merge_lora(...)` above to have run).\n",
    "import pickle\n",
    "with open('lora_theirmerge.pkl', 'wb') as f:\n",
    "    pickle.dump(model.state_dict(), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98077b18-297b-40bb-8984-0700239c11f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test pickle: reload the merged model from disk and snapshot a few q_proj weights.\n",
    "device_arg = {'device_map': 'auto'}\n",
    "merged_lima = AutoModelForCausalLM.from_pretrained(\n",
    "    \"merged-lima\",\n",
    "    return_dict=True,\n",
    "    torch_dtype=torch.float16,\n",
    "    **device_arg\n",
    ")\n",
    "\n",
    "import pickle\n",
    "\n",
    "layer_indices = [0, 5, 19, 33]\n",
    "merge_weights = [merged_lima.model.layers[i].self_attn.q_proj.weight.data for i in layer_indices]\n",
    "\n",
    "# Dictionary keyed by layer index, valued by the weight tensor\n",
    "weights_dict = {index: tensor for index, tensor in zip(layer_indices, merge_weights)}\n",
    "\n",
    "# Serialize the dictionary to 'weights.pkl'\n",
    "with open('weights.pkl', 'wb') as f:\n",
    "    pickle.dump(weights_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3225c9b4-8c3c-4e85-a102-4efc9e7c7d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "device_arg = {'device_map': 'auto'}\n",
    "\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    base,\n",
    "    return_dict=True,\n",
    "    torch_dtype=torch.float16,\n",
    "    **device_arg\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9261de84-daa4-451d-9001-f856bc2e1746",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count parameters overall vs. parameters that live inside a decoder layer,\n",
    "# and print the names of the ones that don't (embeddings, final norm, lm_head).\n",
    "layer_dict = {}\n",
    "length = 0\n",
    "layers = 0\n",
    "\n",
    "for name, param in base_model.named_parameters():\n",
    "    length += 1\n",
    "    old_layers = layers\n",
    "    if 'layers.' in name:\n",
    "        # Extract the layer index from the parameter name\n",
    "        layer_index = int(name.split('.')[2])  # Adjust this depending on the exact format of the names\n",
    "        if layer_index not in layer_dict:\n",
    "            layer_dict[layer_index] = []\n",
    "        layer_dict[layer_index].append((name, param.shape))\n",
    "        layers += 1\n",
    "    if old_layers == layers:\n",
    "        print(name)\n",
    "\n",
    "print(length)\n",
    "print(layers)"
   ]
  },
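  {
   "cell_type": "code",
   "execution_count": null,
   "id": "layer-index-regex-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A sketch of a more robust way to pull the layer index out of a parameter\n",
    "# name than name.split('.')[2], since the prefix depth differs between the\n",
    "# bare model ('model.layers.0...') and the PEFT wrapper\n",
    "# ('base_model.model.model.layers.0...').\n",
    "import re\n",
    "\n",
    "def layer_index_of(name):\n",
    "    match = re.search(r'\\\\blayers\\\\.(\\\\d+)\\\\.', name)\n",
    "    return int(match.group(1)) if match else None\n",
    "\n",
    "print(layer_index_of('model.layers.19.self_attn.q_proj.weight'))                  # 19\n",
    "print(layer_index_of('base_model.model.model.layers.0.self_attn.q_proj.weight'))  # 0\n",
    "print(layer_index_of('lm_head.weight'))                                           # None"
   ]
  },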
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da80121d-7411-486e-9ffd-72ca3b1b8759",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_model = PeftModel.from_pretrained(base_model, lora, **device_arg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a9ddc7b-775a-4faf-8425-d75f238b505c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(lora_model.peft_config)\n",
    "r = lora_model.peft_config[\"default\"].r\n",
    "alpha = lora_model.peft_config[\"default\"].lora_alpha\n",
    "\n",
    "# LoRA scales the low-rank update by alpha / r\n",
    "scaling = alpha / r\n",
    "print(scaling)"
   ]
  },
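  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lora-formula-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A toy numeric sketch of the merge being investigated (random tensors, not\n",
    "# model weights): the merged weight is W' = W + (alpha / r) * (B @ A), where\n",
    "# A is (r, in_features) and B is (out_features, r).\n",
    "import torch\n",
    "\n",
    "W = torch.randn(8, 8)\n",
    "A = torch.randn(2, 8)   # lora_A\n",
    "B = torch.randn(8, 2)   # lora_B\n",
    "toy_scaling = 16 / 2    # alpha=16, r=2 (arbitrary values for the demo)\n",
    "\n",
    "W_merged = W + toy_scaling * (B @ A)\n",
    "print(W_merged.shape, (W_merged - W).abs().max())"
   ]
  },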
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c92c54b-12c9-4878-a844-31400d23b524",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test pointers: the PEFT wrapper should hold the *same* weight tensors as the\n",
    "# base model, not copies.\n",
    "lcount = 0\n",
    "\n",
    "zipped = zip(base_model.named_parameters(), lora_model.named_parameters())\n",
    "for (bm, bp), (lm, lp) in zipped:\n",
    "    # print(bm, lm)\n",
    "    lcount += 1\n",
    "\n",
    "assert base_model.model.layers[0].self_attn.q_proj.weight is lora_model.model.model.layers[0].self_attn.q_proj.weight\n",
    "\n",
    "old_test = base_model.model.layers[0].self_attn.q_proj.weight.data.clone()"
   ]
  },
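  {
   "cell_type": "code",
   "execution_count": null,
   "id": "shared-storage-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the consequence of the identity check above: because the wrapper\n",
    "# shares storage with the base model, merging in place through lora_model also\n",
    "# mutates base_model. data_ptr() exposes the underlying storage address.\n",
    "base_ptr = base_model.model.layers[0].self_attn.q_proj.weight.data_ptr()\n",
    "lora_ptr = lora_model.model.model.layers[0].self_attn.q_proj.weight.data_ptr()\n",
    "print(base_ptr == lora_ptr)  # expected: True"
   ]
  },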
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8bb683a-e2e9-4139-a70d-dc83f9f8b9d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.allclose(old_test, base_model.model.layers[0].self_attn.q_proj.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "968a8074-5b4f-46ac-9934-be490196e578",
   "metadata": {},
   "outputs": [],
   "source": [
    "count = 0\n",
    "for name, param in lora_model.named_parameters():\n",
    "    count += 1\n",
    "print(count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33413d83-0b9a-4178-bb7b-57e84a61a42f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Strip the adapter modules off the wrapper and confirm the parameter count\n",
    "# drops back toward the base model's.\n",
    "for name, module in lora_model.named_modules():\n",
    "    if hasattr(module, 'lora_A'):\n",
    "        delattr(module, 'lora_A')\n",
    "    if hasattr(module, 'lora_B'):\n",
    "        delattr(module, 'lora_B')\n",
    "\n",
    "count = 0\n",
    "for name, param in lora_model.named_parameters():\n",
    "    count += 1\n",
    "print(count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b2a888f-9d88-4a5c-b195-89126e564ac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manual merge: for every base weight that has a LoRA pair, add scaling * (B @ A)\n",
    "# in place. Because the wrapper shares storage with base_model, this mutates the\n",
    "# base weights directly.\n",
    "weights_list = []\n",
    "\n",
    "# Loop over all parameters\n",
    "for name, param in lora_model.named_parameters():\n",
    "    # A '.weight' name that is not itself a lora_A/lora_B tensor is an original weight\n",
    "    if name.endswith('.weight') and not any(s in name for s in ['lora_A', 'lora_B']):\n",
    "        # Walk the attribute path down to the module that owns this weight\n",
    "        parts = name.split('.')\n",
    "        try:\n",
    "            layer = lora_model\n",
    "            for item in parts[:-1]:  # stop before the trailing 'weight'\n",
    "                layer = getattr(layer, item)\n",
    "\n",
    "            # Grab the module's lora_A / lora_B weights, if it has them\n",
    "            lora_A = getattr(layer, 'lora_A').default.weight\n",
    "            lora_B = getattr(layer, 'lora_B').default.weight\n",
    "\n",
    "            # Record a tuple with the parameter name as the first item\n",
    "            weights_list.append((name, param.data, lora_A, lora_B))\n",
    "\n",
    "        except AttributeError:\n",
    "            pass\n",
    "            # print(f\"Unable to find lora_A or lora_B weights for {name}\")\n",
    "\n",
    "for (name, weight, a, b) in weights_list:\n",
    "    ab = b @ a  # (out_features, r) @ (r, in_features)\n",
    "    weight += ab * scaling  # in-place fp16 update\n",
    "    print(f\"Merged LoRA delta into {name}\")\n",
    "\n",
    "# a = lora_model.model.model.layers[0].self_attn.q_proj.lora_A.default.weight\n",
    "# b = lora_model.model.model.layers[0].self_attn.q_proj.lora_B.default.weight\n",
    "# print(a)\n",
    "# print(b)\n",
    "\n",
    "# c = a@b\n",
    "# print(c)"
   ]
  },
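  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fp32-merge-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A hedged sketch of one possible source of the tiny MSE/MAE gaps measured\n",
    "# below (an assumption, not something this notebook establishes): the in-place\n",
    "# fp16 update above rounds twice (once for B @ A, once for the add). Computing\n",
    "# the delta in fp32 and casting once at the end is a common variant.\n",
    "def merge_delta_fp32(weight, a, b, scale):\n",
    "    delta = (b.float() @ a.float()) * scale\n",
    "    return (weight.float() + delta).to(weight.dtype)\n",
    "\n",
    "# usage sketch: weight.copy_(merge_delta_fp32(weight, a, b, scaling))"
   ]
  },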
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "665f40f4-2f3c-44f2-9040-e4de7acbedf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot the manually merged state dict for comparison with lora_theirmerge.pkl.\n",
    "import pickle\n",
    "with open('lora_mymerge.pkl', 'wb') as f:\n",
    "    pickle.dump(lora_model.state_dict(), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29959000-607e-42ab-814d-b9769f3fa0a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "my_method = lora_model.model.model.layers[5].self_attn.q_proj.weight.data.to(\"cuda:0\")\n",
    "# Load and test pickle\n",
    "with open('weights.pkl', 'rb') as f:\n",
    "    loaded_weights = pickle.load(f)\n",
    "\n",
    "    their_method = loaded_weights[5]\n",
    "    print(their_method)\n",
    "    print(my_method)\n",
    "    print(torch.allclose(their_method, my_method, atol=1e-04, rtol=1e-04))\n",
    "\n",
    "    close = torch.isclose(their_method, my_method, equal_nan=True)\n",
    "    not_close = ~close  # Boolean tensor, True where the values aren't close\n",
    "\n",
    "    # Use not_close to index both tensors and pull out the mismatched values\n",
    "    their_method_not_close = their_method[not_close]\n",
    "    my_method_not_close = my_method[not_close]\n",
    "\n",
    "    print(\"Their method values that aren't close: \", their_method_not_close)\n",
    "    print(\"My method values that aren't close: \", my_method_not_close)\n",
    "\n",
    "    mse = torch.mean((their_method - my_method) ** 2)\n",
    "    print(f'MSE: {mse.item():.10f}')\n",
    "\n",
    "    mae = torch.mean(torch.abs(their_method - my_method))\n",
    "    print(f'MAE: {mae.item():.10f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f04285-7ddc-484f-a571-46e1eb5bb879",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test pickle: snapshot the same q_proj layers from the manual merge.\n",
    "import pickle\n",
    "\n",
    "layer_indices = [0, 5, 19, 33]\n",
    "merge_weights = [lora_model.model.model.layers[i].self_attn.q_proj.weight.data for i in layer_indices]\n",
    "\n",
    "# Dictionary keyed by layer index, valued by the weight tensor\n",
    "weights_dict = {index: tensor for index, tensor in zip(layer_indices, merge_weights)}\n",
    "\n",
    "# Serialize the dictionary to 'weights2.pkl'\n",
    "with open('weights2.pkl', 'wb') as f:\n",
    "    pickle.dump(weights_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "615af4f0-e680-42aa-abdc-110c3d4b4093",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "# Load and compare both pickles\n",
    "with open('weights.pkl', 'rb') as f:\n",
    "    loaded_weights = pickle.load(f)\n",
    "\n",
    "    with open('weights2.pkl', 'rb') as q:\n",
    "        loaded_weights2 = pickle.load(q)\n",
    "\n",
    "    their_method = loaded_weights[19].to(\"cuda:0\")\n",
    "    my_method = loaded_weights2[19].to(\"cuda:0\")\n",
    "    print(torch.allclose(their_method, my_method, atol=1e-04, rtol=1e-04))\n",
    "\n",
    "    close = torch.isclose(their_method, my_method, equal_nan=True)\n",
    "    not_close = ~close  # Boolean tensor, True where the values aren't close\n",
    "\n",
    "    # Use not_close to index both tensors and pull out the mismatched values\n",
    "    their_method_not_close = their_method[not_close]\n",
    "    my_method_not_close = my_method[not_close]\n",
    "\n",
    "    print(\"Their method values that aren't close: \", their_method_not_close)\n",
    "    print(\"My method values that aren't close: \", my_method_not_close)\n",
    "\n",
    "    mse = torch.mean((their_method - my_method) ** 2)\n",
    "    print(f'MSE: {mse.item():.10f}')\n",
    "\n",
    "    mae = torch.mean(torch.abs(their_method - my_method))\n",
    "    print(f'MAE: {mae.item():.10f}')"
   ]
  },
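  {
   "cell_type": "code",
   "execution_count": null,
   "id": "compare-helper-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The two cells above repeat the same comparison boilerplate; a small helper\n",
    "# (added here as a sketch, not used by the original investigation) keeps the\n",
    "# allclose / mismatch / MSE / MAE reporting in one place.\n",
    "def compare_tensors(theirs, mine, atol=1e-04, rtol=1e-04):\n",
    "    print('allclose:', torch.allclose(theirs, mine, atol=atol, rtol=rtol))\n",
    "    not_close = ~torch.isclose(theirs, mine, equal_nan=True)\n",
    "    print('mismatched values (theirs):', theirs[not_close])\n",
    "    print('mismatched values (mine):  ', mine[not_close])\n",
    "    print(f'MSE: {torch.mean((theirs - mine) ** 2).item():.10f}')\n",
    "    print(f'MAE: {torch.mean(torch.abs(theirs - mine)).item():.10f}')\n",
    "\n",
    "# usage sketch: compare_tensors(loaded_weights[19].cuda(), loaded_weights2[19].cuda())"
   ]
  },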
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2e56862-950a-4224-b85b-872d2e00e781",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference: the merge implementation inside PEFT's LoRA Linear layer, pasted\n",
    "# here for comparison with the manual merge above. Not runnable standalone: it\n",
    "# relies on peft-internal helpers (`transpose`) and the `warnings` module.\n",
    "def merge(self):\n",
    "    if self.active_adapter not in self.lora_A.keys():\n",
    "        return\n",
    "    if self.merged:\n",
    "        warnings.warn(\"Already merged. Nothing to do.\")\n",
    "        return\n",
    "    if self.r[self.active_adapter] > 0:\n",
    "        self.weight.data += self.get_delta_weight(self.active_adapter)\n",
    "        self.merged = True\n",
    "\n",
    "def get_delta_weight(self, adapter):\n",
    "    return (\n",
    "        transpose(\n",
    "            self.lora_B[adapter].weight @ self.lora_A[adapter].weight,\n",
    "            self.fan_in_fan_out,\n",
    "        )\n",
    "        * self.scaling[adapter]\n",
    "    )"
   ]
  },
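  {
   "cell_type": "code",
   "execution_count": null,
   "id": "transpose-helper-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For reference (my paraphrase of the peft utility, not a verbatim copy):\n",
    "# `transpose` only flips the delta for Conv1D-style layers whose weight is\n",
    "# stored (in_features, out_features); for ordinary nn.Linear, fan_in_fan_out is\n",
    "# False and the delta is used as-is, i.e. exactly scaling * (B @ A).\n",
    "def transpose(weight, fan_in_fan_out):\n",
    "    return weight.T if fan_in_fan_out else weight"
   ]
  },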
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dea2453c-5075-400d-923d-07a106c8493c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import torch\n",
    "\n",
    "# Load the two merged state dicts\n",
    "with open('lora_mymerge.pkl', 'rb') as f:\n",
    "    model1 = pickle.load(f)\n",
    "\n",
    "with open('lora_theirmerge.pkl', 'rb') as f:\n",
    "    model2 = pickle.load(f)\n",
    "\n",
    "difference = False\n",
    "\n",
    "# Compare every parameter present in model1 (and flag keys model2 lacks)\n",
    "for param in model1:\n",
    "    if param not in model2:\n",
    "        print(f\"Parameter only in my merge: {param}\")\n",
    "        difference = True\n",
    "        continue\n",
    "\n",
    "    layer1 = model1[param].to(\"cuda:0\")\n",
    "    layer2 = model2[param].to(\"cuda:0\")\n",
    "\n",
    "    # Compare the layers\n",
    "    if not torch.allclose(layer1, layer2):\n",
    "        print(f\"Difference found in parameter: {param}\")\n",
    "        difference = True\n",
    "\n",
    "    # Delete the GPU copies to free up memory\n",
    "    del layer1\n",
    "    del layer2\n",
    "\n",
    "# The second pass only needs to catch keys in model2 that model1 lacks\n",
    "for param in model2:\n",
    "    if param not in model1:\n",
    "        print(f\"Parameter only in their merge: {param}\")\n",
    "        difference = True\n",
    "\n",
    "if not difference:\n",
    "    print(\"No differences\")"
   ]
  },
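  {
   "cell_type": "code",
   "execution_count": null,
   "id": "keyset-diff-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A compact sketch (added for illustration) of the key bookkeeping above using\n",
    "# set algebra: both differences are empty when the two merges produced state\n",
    "# dicts with identical key sets.\n",
    "only_mine = set(model1) - set(model2)\n",
    "only_theirs = set(model2) - set(model1)\n",
    "print('keys only in my merge:   ', sorted(only_mine)[:5])\n",
    "print('keys only in their merge:', sorted(only_theirs)[:5])"
   ]
  }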
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}