{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from transformers import (\n",
" AutoTokenizer,\n",
" BitsAndBytesConfig,\n",
" MixtralModel,\n",
" MixtralConfig,\n",
" AutoModelForCausalLM\n",
")\n",
"from torch import nn\n",
"import collections\n",
"import time\n",
"import os\n",
"from einops import rearrange\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mixtral-8x7B-v0.1\")"
]
},
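{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at the tokenizer (a sketch; not required for the run): the Mixtral tokenizer ships without a pad token, which is why the probing cell below reuses the EOS token for padding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# pad_token is None until the probing cell sets it to the EOS token\n",
"print(tokenizer.eos_token, tokenizer.pad_token)\n",
"print(tokenizer(\"hello world\")[\"input_ids\"])"
]
},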
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"layer = 15\n",
"\n",
"W2 = 0\n",
"GATE = 1\n",
"\n",
"sequence_length = 128\n",
"batch_size = 6\n",
"num_experts_per_token = 2\n",
"num_experts = 8\n",
"expert_ffn_dims = 14336\n",
"expert_hidden_dims = 4096"
]
},
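{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check, assuming the Hugging Face Hub is reachable: the hand-coded dimensions above should agree with the published `mistralai/Mixtral-8x7B-v0.1` config."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# cross-check the constants against the published model config\n",
"from transformers import AutoConfig\n",
"\n",
"hf_config = AutoConfig.from_pretrained(\"mistralai/Mixtral-8x7B-v0.1\")\n",
"assert hf_config.hidden_size == expert_hidden_dims\n",
"assert hf_config.intermediate_size == expert_ffn_dims\n",
"assert hf_config.num_local_experts == num_experts\n",
"assert hf_config.num_experts_per_tok == num_experts_per_token"
]
},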
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"`low_cpu_mem_usage` was None, now set to True since model is quantized.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b31d9d112a1745b9814563e9e0986012",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/19 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"double_quant_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=True,\n",
")\n",
"\n",
"config = MixtralConfig(\n",
" num_experts_per_tok=num_experts_per_token,\n",
" num_hidden_layers=layer + 1,\n",
")\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" \"mistralai/Mixtral-8x7B-v0.1\",\n",
" quantization_config=double_quant_config,\n",
" attn_implementation=\"flash_attention_2\",\n",
" config=config,\n",
")"
]
},
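{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal check that the truncated config took effect: only layers `0..layer` should have been instantiated, and the probed layer should expose the sparse-MoE block (the `gate` router plus eight `experts`) that the hooks below attach to."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the truncated model should stop right after the probed layer\n",
"print(len(model.model.layers))  # expected: layer + 1 == 16\n",
"\n",
"moe_block = model.model.layers[layer].block_sparse_moe\n",
"print(len(moe_block.experts))  # expected: num_experts == 8\n",
"print(moe_block.gate)  # router projecting hidden states to 8 expert logits"
]
},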
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def load_generated_dataset(\n",
" filename, batch_size, dataset_relative_path=\"../dataset\"\n",
"):\n",
" dataset = pd.read_csv(f\"{dataset_relative_path}/{filename}\")\n",
" # print(dataset.head(10))\n",
" dataset_list = dataset[\"0\"].tolist()\n",
" dataloader = torch.utils.data.DataLoader(dataset_list, batch_size=batch_size)\n",
" return dataloader\n",
"\n",
"\n",
"dataset = load_generated_dataset(\"pile.csv\", batch_size=batch_size)"
]
},
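{
"cell_type": "markdown",
"metadata": {},
"source": [
"Peeking at a single batch: the `DataLoader` yields a plain Python list of `batch_size` strings, and tokenization happens later inside the probing loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# each batch is a list of `batch_size` raw text strings\n",
"sample_batch = next(iter(dataset))\n",
"print(len(sample_batch))\n",
"print(sample_batch[0][:200])"
]
},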
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 1/1939 [00:00<14:12, 2.27it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 768, 4096])\n",
"torch.Size([768, 32768])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"i = 0\n",
"\n",
"moe = model.model.layers[layer].block_sparse_moe\n",
"\n",
"hooks = []\n",
"\n",
"for batch in tqdm(dataset):\n",
" if i > 0:\n",
" break\n",
" i += 1\n",
"\n",
" mlps = [None] * num_experts\n",
" router_logits = None\n",
"\n",
" def getActivation(expert_idx, type):\n",
"\n",
" def hook(model, input, output):\n",
" global router_logits\n",
"\n",
" if type == W2:\n",
" mlps[expert_idx] = output.detach()\n",
" elif type == GATE:\n",
" router_logits = output.detach()\n",
"\n",
" return hook\n",
"\n",
" for expert_idx in range(num_experts):\n",
"\n",
" w2_hook = moe.experts[expert_idx].w2.register_forward_hook(\n",
" getActivation(expert_idx, W2)\n",
" )\n",
" gate_hook = moe.gate.register_forward_hook(getActivation(expert_idx, GATE))\n",
" hooks.extend([w2_hook, gate_hook])\n",
"\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
" batch_tokens = tokenizer(\n",
" batch,\n",
" padding=\"max_length\",\n",
" truncation=True,\n",
" max_length=sequence_length,\n",
" return_tensors=\"pt\",\n",
" ).to(device)\n",
"\n",
" try:\n",
" output = model(batch_tokens[\"input_ids\"])\n",
" except Exception as e:\n",
" print(\"Exception occured, removing hooks\")\n",
" for hook in hooks:\n",
" hook.remove()\n",
"\n",
" routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n",
" routing_weights, selected_experts = torch.topk(routing_weights, num_experts_per_token, dim=-1)\n",
" routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n",
"\n",
" final_hidden_states = torch.zeros(\n",
" (batch_size * sequence_length, expert_hidden_dims),\n",
" device=device,\n",
" )\n",
"\n",
" expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=8).permute(\n",
" 2, 1, 0\n",
" )\n",
"\n",
" activations = torch.zeros(\n",
" num_experts, batch_size * sequence_length, expert_hidden_dims, device=device\n",
" )\n",
"\n",
" for expert_idx in range(num_experts):\n",
"\n",
" idx, top_x = torch.where(expert_mask[expert_idx])\n",
"\n",
" if top_x.shape[0] == 0:\n",
" continue\n",
"\n",
" top_x_list = top_x.tolist()\n",
" idx_list = idx.tolist()\n",
"\n",
" current_hidden_states = (\n",
" mlps[expert_idx] * routing_weights[top_x_list, idx_list, None]\n",
" )\n",
" activations[expert_idx][top_x] = current_hidden_states\n",
"\n",
" activations = rearrange(activations, 'experts sequences hidden -> sequences (experts hidden)')\n",
" print(activations.shape)\n",
"\n",
" for hook in hooks:\n",
" hook.remove()\n",
"\n",
" # print(tokenizer.batch_decode(generated_ids))"
]
}
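,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small follow-up sketch: summarize how the router spread this batch's token slots across the eight experts, using the `selected_experts` tensor left over from the loop above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# count the (token, slot) assignments each expert received; with top-2\n",
"# routing the counts sum to 2 * batch_size * sequence_length\n",
"expert_counts = torch.bincount(selected_experts.flatten(), minlength=num_experts)\n",
"for e, count in enumerate(expert_counts.tolist()):\n",
"    print(f\"expert {e}: {count} token slots\")"
]
}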
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}