Last active
May 24, 2024 13:51
-
-
Save KeremTurgutlu/a99e138e7fca7c9feb6cc9b74394b89e to your computer and use it in GitHub Desktop.
test_triton_mm.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "f7e69d06-de3c-487c-ad62-7aebce775e15", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "04d16c8e-bfba-4e6b-9dd9-58daae15135e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from vllm.model_executor.layers.quantization.triton_mm import triton_mixed_mm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "b349f02c-7df3-4942-861e-523f00e34436", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "7335bab7-6909-45cf-b623-6468052940c8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def pack_2xint4(t):\n", | |
" \"\"\"\n", | |
" The packing format is such that consecutive rows are packed into a lower / upper bits\n", | |
" E.g.,\n", | |
" Original, unpacked B (dtype i8):\n", | |
" [\n", | |
" [0, 1, 2, 3]\n", | |
" [4, 5, 6, 7]\n", | |
" [8, 9, 10, 11]\n", | |
" [12, 13, 14, 15]\n", | |
" ]\n", | |
" Packed B:\n", | |
" [\n", | |
" [0|4, 1|5, 2|6, 3|7]\n", | |
" [8|12, 9|13, 10|14, 11|15]\n", | |
" ]\n", | |
" (Note each entry in `Packed B` is shown lsb->msb)\n", | |
" \"\"\"\n", | |
" assert t.dtype == torch.int8 or t.dtype == torch.uint8\n", | |
" t = t.reshape(t.shape[0] // 2, 2, t.shape[1]).permute(1, 0, 2)\n", | |
" return (t[0] & 0xF) | (t[1] << 4)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "1ad339f8-d5df-4740-81c8-61f46eba450b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def patch_hqq_to_tritonmm(layer, patch_param):\n", | |
" if(isinstance(layer, HQQLinear)):\n", | |
"\n", | |
" #Handle no grouping case\n", | |
" shape = layer.meta['shape']\n", | |
" layer.group_size = layer.quant_config['weight_quant_params']['group_size']\n", | |
" if(layer.group_size is None):\n", | |
" layer.group_size = shape[1] \n", | |
"\n", | |
" #Update scale/zero\n", | |
" M, N = shape[::-1]\n", | |
" layer.meta ['scale'] = layer.meta ['scale'].reshape(N, -1).T \n", | |
" layer.meta ['zero'] = layer.meta ['zero'].reshape(N, -1).T \n", | |
"\n", | |
" #Repack\n", | |
" layer.W_q.data = pack_2xint4(Quantizer.unpack[layer.meta ['packing']](layer.W_q).reshape(layer.meta [\"shape\"]).T).data \n", | |
"\n", | |
" #Set pred vals\n", | |
" layer.fp8_fast_accum = True #False \n", | |
" layer.kernel_type = \"max_autotune\" #max_autotune\n", | |
"\n", | |
" def matmul_tritonmm(self, x, transpose=True):\n", | |
"\n", | |
" out_dim = self.meta['shape'][0] if (transpose) else self.meta['shape'][1]\n", | |
" out = triton_mixed_mm(x.view([-1, x.shape[-1]]),\n", | |
" self.W_q,\n", | |
" self.meta[\"scale\"],\n", | |
" self.meta[\"zero\"],\n", | |
" group_size=self.group_size,\n", | |
" fp8_fast_accum=self.fp8_fast_accum,\n", | |
" kernel_type=self.kernel_type,\n", | |
" transposed=not transpose,\n", | |
" ).view([x.shape[0], x.shape[1], out_dim])\n", | |
"\n", | |
" return out \n", | |
"\n", | |
" def forward_tritonmm_backprop(self, x):\n", | |
" return HQQMatmulNoCacheMul.apply(x, self.matmul, self.bias)\n", | |
"\n", | |
" def forward_tritonmm_forward(self, x):\n", | |
" out = self.matmul(x)\n", | |
" if(self.bias is not None):\n", | |
" out += self.bias\n", | |
" return out \n", | |
"\n", | |
"\n", | |
" layer.matmul = lambda x, transpose: matmul_tritonmm(layer, x, transpose)\n", | |
" layer.forward = lambda x: forward_tritonmm_backprop(layer, x)\n", | |
"\n", | |
"\n", | |
" return layer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "6ff9ca4b-9ddf-40a5-b185-c3ec886f02ed", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"quant_config = BaseQuantizeConfig(nbits=4,\n", | |
" group_size=64, \n", | |
" quant_zero=False,\n", | |
" quant_scale=False,\n", | |
" offload_meta=False,\n", | |
" view_as_float=False, \n", | |
" axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "8fadd71a-7c24-4eed-a8ad-7c60af6284e6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"q_weight = torch.randn(4096, 4096) # output x input\n", | |
"k_weight = torch.randn(1024, 4096)\n", | |
"v_weight = torch.randn(1024, 4096)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "4d9529d0-8b1d-4b31-9b53-7422f8243136", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dtype = torch.bfloat16\n", | |
"triton_params = {}\n", | |
"for name,p in [(\"q\", q_weight), (\"k\",k_weight), (\"v\",v_weight)]:\n", | |
" m = torch.nn.Linear(*p.T.shape, bias=False)\n", | |
" m.weight.data.copy_(p)\n", | |
" dummy_hqq_linear = HQQLinear(m, quant_config, compute_dtype=dtype)\n", | |
" patched_hqq_linear = patch_hqq_to_tritonmm(dummy_hqq_linear, None)\n", | |
" triton_params[name] = {\"Wq\":patched_hqq_linear.W_q, \n", | |
" \"scale\":patched_hqq_linear.meta['scale'], \n", | |
" \"zero\":patched_hqq_linear.meta['zero']}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "2f7f93a3-5118-4766-9f24-5e769ff3841e", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"qkv_weight = torch.cat([triton_params[k][\"Wq\"] for k in [\"q\", \"k\", \"v\"]], dim=1)\n", | |
"qkv_scale = torch.cat([triton_params[k][\"scale\"] for k in [\"q\", \"k\", \"v\"]], dim=1)\n", | |
"qkv_zero = torch.cat([triton_params[k][\"zero\"] for k in [\"q\", \"k\", \"v\"]], dim=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "e608e1f0-3cac-4dca-b1df-bd669de6e717", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([2048, 6144])" | |
] | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"qkv_weight.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"id": "c4c025e3-a171-4801-8428-9e588bd516e7", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"x = torch.randn(16,4096, device=\"cuda\", dtype=torch.bfloat16)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"id": "6dea6917-6644-41c5-af61-5f235fa653aa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"output_qkv = triton_mixed_mm(x,\n", | |
" qkv_weight,\n", | |
" qkv_scale,\n", | |
" qkv_zero,\n", | |
" group_size=quant_config['weight_quant_params']['group_size'],\n", | |
" fp8_fast_accum=False,\n", | |
" kernel_type=\"compute_bound\",\n", | |
" transposed=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"id": "4bbc933e-a807-41db-992b-4bec442aadd2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"output_q = triton_mixed_mm(x,\n", | |
" triton_params[\"q\"][\"Wq\"],\n", | |
" triton_params[\"q\"][\"scale\"],\n", | |
" triton_params[\"q\"][\"zero\"],\n", | |
" group_size=quant_config['weight_quant_params']['group_size'],\n", | |
" fp8_fast_accum=False,\n", | |
" kernel_type=\"compute_bound\",\n", | |
" transposed=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"id": "4fbe32cd-c5a4-4ab8-a99c-1d028e4764fa", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"False" | |
] | |
}, | |
"execution_count": 43, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.equal(output_qkv[:,:4096], output_q)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"id": "ce2bc22b-9961-420f-b67b-d07fe309aa3d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"output_k = triton_mixed_mm(x,\n", | |
" triton_params[\"k\"][\"Wq\"],\n", | |
" triton_params[\"k\"][\"scale\"],\n", | |
" triton_params[\"k\"][\"zero\"],\n", | |
" group_size=quant_config['weight_quant_params']['group_size'],\n", | |
" fp8_fast_accum=False,\n", | |
" kernel_type=\"compute_bound\",\n", | |
" transposed=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 52, | |
"id": "cf9e24fe-ad1f-4c97-a0eb-7f13482d5b52", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"False" | |
] | |
}, | |
"execution_count": 52, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.equal(output_qkv[:,4096:5120], output_k)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 53, | |
"id": "fbf45ba1-5282-4cc5-a96b-849d162f8adf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"output_v = triton_mixed_mm(x,\n", | |
" triton_params[\"v\"][\"Wq\"],\n", | |
" triton_params[\"v\"][\"scale\"],\n", | |
" triton_params[\"v\"][\"zero\"],\n", | |
" group_size=quant_config['weight_quant_params']['group_size'],\n", | |
" fp8_fast_accum=False,\n", | |
" kernel_type=\"compute_bound\",\n", | |
" transposed=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 54, | |
"id": "c3be901d-10e8-42ba-8392-ee48ba5b8967", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"False" | |
] | |
}, | |
"execution_count": 54, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.equal(output_qkv[:,5120:], output_v)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 57, | |
"id": "21dc7908-ce3e-4f60-a9ee-8dbd733a92ad", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"False" | |
] | |
}, | |
"execution_count": 57, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.equal(output_qkv, torch.cat([output_q, output_k, output_v], dim=1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "a50c3def-6633-4ef8-b7f7-2bd8c8701d55", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c01f04a9-9f90-4ad0-9ad0-ec6c74804d05", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "badb7b0f-32c2-48e4-a1fa-27254cd548c0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment