{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from einops import rearrange\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    BitsAndBytesConfig,\n",
    "    MixtralConfig,\n",
    "    AutoModelForCausalLM,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mixtral-8x7B-v0.1\")\n",
    "# Mixtral's tokenizer has no pad token; reuse EOS so padded batches work.\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the decoder layer whose MoE block is instrumented.\n",
    "layer = 15\n",
    "\n",
    "# Tags distinguishing the two kinds of forward hooks registered below.\n",
    "W2 = 0    # an expert's down-projection (w2) output\n",
    "GATE = 1  # the router's (gate's) logits\n",
    "\n",
    "sequence_length = 128\n",
    "batch_size = 6\n",
    "num_experts_per_token = 2\n",
    "num_experts = 8\n",
    "expert_ffn_dims = 14336    # Mixtral-8x7B intermediate (FFN) size\n",
    "expert_hidden_dims = 4096  # Mixtral-8x7B hidden size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`low_cpu_mem_usage` was None, now set to True since model is quantized.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b31d9d112a1745b9814563e9e0986012",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 4-bit quantization (with double quantization) to fit the model on one GPU.\n",
    "double_quant_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    ")\n",
    "\n",
    "# Truncate the model to `layer + 1` decoder layers: activations past the\n",
    "# instrumented MoE block are never needed, so later layers are dropped.\n",
    "config = MixtralConfig(\n",
    "    num_experts_per_tok=num_experts_per_token,\n",
    "    num_hidden_layers=layer + 1,\n",
    ")\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"mistralai/Mixtral-8x7B-v0.1\",\n",
    "    quantization_config=double_quant_config,\n",
    "    attn_implementation=\"flash_attention_2\",\n",
    "    config=config,\n",
    ")"
   ]
  },
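  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sanity check worth running before registering hooks, assuming the load above succeeded: the truncated config should leave exactly `layer + 1` decoder layers, and `get_memory_footprint()` (a standard `transformers` helper) gives a rough size for the 4-bit weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The truncated model should expose only `layer + 1` decoder layers.\n",
    "assert len(model.model.layers) == layer + 1\n",
    "print(f\"{model.get_memory_footprint() / 1e9:.1f} GB\")  # approximate 4-bit footprint"
   ]
  },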
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_generated_dataset(filename, batch_size, dataset_relative_path=\"../dataset\"):\n",
    "    # The CSV is expected to have a single text column named \"0\".\n",
    "    dataset = pd.read_csv(f\"{dataset_relative_path}/{filename}\")\n",
    "    dataset_list = dataset[\"0\"].tolist()\n",
    "    dataloader = torch.utils.data.DataLoader(dataset_list, batch_size=batch_size)\n",
    "    return dataloader\n",
    "\n",
    "\n",
    "dataset = load_generated_dataset(\"pile.csv\", batch_size=batch_size)"
   ]
  },
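  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tokenizer call in the loop below expects each batch to be a plain list of strings, which is what `DataLoader` yields for a list-of-strings dataset. Peeking at the first batch is a quick way to confirm this, assuming `pile.csv` exists at the path above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative peek: one batch should be a list of `batch_size` strings.\n",
    "first_batch = next(iter(dataset))\n",
    "print(len(first_batch), type(first_batch[0]))"
   ]
  },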
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 1/1939 [00:00<14:12,  2.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([8, 768, 4096])\n",
      "torch.Size([768, 32768])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "moe = model.model.layers[layer].block_sparse_moe\n",
    "\n",
    "hooks = []\n",
    "\n",
    "for i, batch in enumerate(tqdm(dataset)):\n",
    "    # Only the first batch is processed for now.\n",
    "    if i > 0:\n",
    "        break\n",
    "\n",
    "    # Per-batch buffers filled in by the forward hooks.\n",
    "    mlps = [None] * num_experts\n",
    "    router_logits = None\n",
    "\n",
    "    def get_activation(expert_idx, hook_type):\n",
    "        def hook(module, input, output):\n",
    "            global router_logits\n",
    "            if hook_type == W2:\n",
    "                # This expert's down-projection output for its routed tokens.\n",
    "                mlps[expert_idx] = output.detach()\n",
    "            elif hook_type == GATE:\n",
    "                # Router logits over all experts for every token position.\n",
    "                router_logits = output.detach()\n",
    "        return hook\n",
    "\n",
    "    # One w2 hook per expert; the gate needs only a single hook.\n",
    "    for expert_idx in range(num_experts):\n",
    "        w2_hook = moe.experts[expert_idx].w2.register_forward_hook(\n",
    "            get_activation(expert_idx, W2)\n",
    "        )\n",
    "        hooks.append(w2_hook)\n",
    "    gate_hook = moe.gate.register_forward_hook(get_activation(0, GATE))\n",
    "    hooks.append(gate_hook)\n",
    "\n",
    "    batch_tokens = tokenizer(\n",
    "        batch,\n",
    "        padding=\"max_length\",\n",
    "        truncation=True,\n",
    "        max_length=sequence_length,\n",
    "        return_tensors=\"pt\",\n",
    "    ).to(device)\n",
    "\n",
    "    try:\n",
    "        # No gradients are needed to capture activations.\n",
    "        with torch.no_grad():\n",
    "            output = model(batch_tokens[\"input_ids\"])\n",
    "    except Exception:\n",
    "        print(\"Exception occurred, removing hooks\")\n",
    "        for hook in hooks:\n",
    "            hook.remove()\n",
    "        raise\n",
    "\n",
    "    # Mirror Mixtral's routing: softmax over experts, keep the top-k,\n",
    "    # then renormalize the kept weights so they sum to 1 per token.\n",
    "    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)\n",
    "    routing_weights, selected_experts = torch.topk(\n",
    "        routing_weights, num_experts_per_token, dim=-1\n",
    "    )\n",
    "    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n",
    "\n",
    "    # expert_mask[e, k, t] == 1 iff expert e was token t's k-th choice.\n",
    "    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)\n",
    "\n",
    "    # Weighted per-expert outputs; rows stay zero for tokens an expert never saw.\n",
    "    activations = torch.zeros(\n",
    "        num_experts, batch_size * sequence_length, expert_hidden_dims, device=device\n",
    "    )\n",
    "\n",
    "    for expert_idx in range(num_experts):\n",
    "        idx, top_x = torch.where(expert_mask[expert_idx])\n",
    "        if top_x.shape[0] == 0:\n",
    "            continue\n",
    "\n",
    "        # Scale this expert's captured w2 outputs by its tokens' routing weights.\n",
    "        current_hidden_states = (\n",
    "            mlps[expert_idx] * routing_weights[top_x.tolist(), idx.tolist(), None]\n",
    "        )\n",
    "        activations[expert_idx][top_x] = current_hidden_states\n",
    "\n",
    "    print(activations.shape)\n",
    "    activations = rearrange(\n",
    "        activations, 'experts sequences hidden -> sequences (experts hidden)'\n",
    "    )\n",
    "    print(activations.shape)\n",
    "\n",
    "    for hook in hooks:\n",
    "        hook.remove()\n",
    "    hooks.clear()"
   ]
  },
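  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since each expert slice in `activations` is already scaled by its routing weight, summing the slices over the expert axis should reconstruct the MoE block's combined output per token. Below is a minimal sketch of that check, assuming the capture loop has just run so `activations` still holds the rearranged `(tokens, experts * hidden)` matrix; the block's reference output was not captured above, so only shapes are verified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: undo the rearrange and combine expert contributions per token.\n",
    "per_expert = rearrange(\n",
    "    activations, 'sequences (experts hidden) -> experts sequences hidden',\n",
    "    experts=num_experts,\n",
    ")\n",
    "combined = per_expert.sum(dim=0)  # (batch*seq, hidden): the block's combined output\n",
    "print(combined.shape)             # expected: torch.Size([768, 4096])"
   ]
  }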
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}