{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "de245c82",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e6b14ae1",
   "metadata": {},
   "outputs": [],
   "source": [
    "backend = 'fbgemm'  # x86 server backend; use 'qnnpack' on ARM/mobile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc61857a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.quantization as tq\n",
    "\n",
    "class Model(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.quant = tq.QuantStub()\n",
    "        self.linear1 = nn.Linear(128, 1024)\n",
    "        self.relu1 = nn.ReLU(inplace=False)\n",
    "        self.linear2 = nn.Linear(1024, 1024)\n",
    "        self.dequant = tq.DeQuantStub()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.quant(x)\n",
    "        x = self.linear1(x)\n",
    "        x = self.relu1(x)\n",
    "        x = self.linear2(x)\n",
    "        x = self.dequant(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40ebc075",
   "metadata": {},
   "source": [
    "## Static quantization conversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "27e8bc6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.ao.nn as ao_nn\n",
    "\n",
    "model = Model()\n",
    "model.eval()\n",
    "model.qconfig = tq.get_default_qconfig(backend)"
   ]
  },
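  {
   "cell_type": "markdown",
   "id": "qconfig-inspect-md",
   "metadata": {},
   "source": [
    "The qconfig describes which observer factories `prepare` will attach. For the `fbgemm` backend this should pair a histogram-based activation observer with a per-channel weight observer; printing it is a cheap sanity check before preparing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "qconfig-inspect",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: show the observer factories the default qconfig will use\n",
    "print(model.qconfig)"
   ]
  },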
  {
   "cell_type": "markdown",
   "id": "f2301413",
   "metadata": {},
   "source": [
    "### Prepare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "43eb90f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model(\n",
      "  (quant): QuantStub(\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (linear1): Linear(\n",
      "    in_features=128, out_features=1024, bias=True\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (relu1): ReLU()\n",
      "  (linear2): Linear(\n",
      "    in_features=1024, out_features=1024, bias=True\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (dequant): DeQuantStub()\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zafar/Git/pytorch-dev/pytorch/torch/quantization/observer.py:134: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "tq.prepare(model, inplace=True)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a019a6db",
   "metadata": {},
   "source": [
    "### Calibrate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "da434429",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model(\n",
      "  (quant): QuantStub(\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (linear1): Linear(\n",
      "    in_features=128, out_features=1024, bias=True\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (relu1): ReLU()\n",
      "  (linear2): Linear(\n",
      "    in_features=1024, out_features=1024, bias=True\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (dequant): DeQuantStub()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model(torch.randn(128, 128))\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "38a9e24f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(\n",
       "  in_features=128, out_features=1024, bias=True\n",
       "  (activation_post_process): HistogramObserver()\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.linear1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1bccb1f",
   "metadata": {},
   "source": [
    "### Convert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "52db42c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model(\n",
      "  (quant): Quantize(scale=tensor([0.0579]), zero_point=tensor([64]), dtype=torch.quint8)\n",
      "  (linear1): QuantizedLinear(in_features=128, out_features=1024, scale=0.040794678032398224, zero_point=64, qscheme=torch.per_channel_affine)\n",
      "  (relu1): ReLU()\n",
      "  (linear2): QuantizedLinear(in_features=1024, out_features=1024, scale=0.015324125066399574, zero_point=63, qscheme=torch.per_channel_affine)\n",
      "  (dequant): DeQuantize()\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zafar/Git/pytorch-dev/pytorch/torch/_tensor.py:557: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.\n",
      "To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /home/zafar/Git/pytorch-dev/pytorch/aten/src/ATen/native/BinaryOps.cpp:461.)\n",
      "  return torch.floor_divide(self, other)\n"
     ]
    }
   ],
   "source": [
    "tq.convert(model, inplace=True)\n",
    "print(model)"
   ]
  },
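  {
   "cell_type": "markdown",
   "id": "quantized-forward-md",
   "metadata": {},
   "source": [
    "As a quick sanity check: the converted model should still accept regular float tensors, because the `Quantize`/`DeQuantize` modules that replaced the stubs handle the conversion at the model boundaries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "quantized-forward",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Float in, float out; quantization happens internally at the stub boundaries\n",
    "out = model(torch.randn(8, 128))\n",
    "out.shape, out.dtype"
   ]
  },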
  {
   "cell_type": "markdown",
   "id": "1690219b",
   "metadata": {},
   "source": [
    "## Static sparse conversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6eff751f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Model(\n",
       "  (quant): QuantStub()\n",
       "  (linear1): Linear(in_features=128, out_features=1024, bias=True)\n",
       "  (relu1): ReLU()\n",
       "  (linear2): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "  (dequant): DeQuantStub()\n",
       ")"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch.ao.nn as ao_nn\n",
    "\n",
    "# We don't care about the quantstubs\n",
    "model = Model()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "1dda117b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We don't need prepare or convert, we just need to compute the mask\n",
    "\n",
    "# Step 1. Create a mask\n",
    "model.linear1.register_buffer('mask', torch.ones(model.linear1.weight.shape))\n",
    "model.linear2.register_buffer('mask', torch.ones(model.linear2.weight.shape))\n",
    "\n",
    "# Step 2. Compute some mask\n",
    "model.linear1.mask = torch.randint(0, 2, model.linear1.mask.shape)\n",
    "model.linear2.mask = torch.randint(0, 2, model.linear2.mask.shape)\n",
    "\n",
    "# Step 3. Replace the weight\n",
    "model.linear1.weight.data *= model.linear1.mask\n",
    "model.linear2.weight.data *= model.linear2.mask"
   ]
  },
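  {
   "cell_type": "markdown",
   "id": "float-sparsity-check-md",
   "metadata": {},
   "source": [
    "Because the mask entries are drawn uniformly from {0, 1}, roughly half of each weight matrix should now be exactly zero. The same check used later for the quantized model works directly on the float module here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "float-sparsity-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of zeroed float weights; should be around 50%\n",
    "(model.linear1.weight == 0).float().mean()"
   ]
  },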
  {
   "cell_type": "markdown",
   "id": "deb28ff8",
   "metadata": {},
   "source": [
    "## Quantized sparse conversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "4ecefb4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.ao.nn as ao_nn\n",
    "\n",
    "# This time the stubs do matter: they mark where activations get quantized and dequantized\n",
    "model = Model()\n",
    "model.eval()\n",
    "model.qconfig = tq.get_default_qconfig(backend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "1f2972be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1. Create a mask\n",
    "model.linear1.register_buffer('mask', torch.ones(model.linear1.weight.shape))\n",
    "model.linear2.register_buffer('mask', torch.ones(model.linear2.weight.shape))\n",
    "\n",
    "# Step 2. Compute some mask\n",
    "model.linear1.mask = torch.randint(0, 2, model.linear1.mask.shape)\n",
    "model.linear2.mask = torch.randint(0, 2, model.linear2.mask.shape)\n",
    "\n",
    "# Step 3. Replace the weight\n",
    "model.linear1.weight.data *= model.linear1.mask\n",
    "model.linear2.weight.data *= model.linear2.mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3652f6d2",
   "metadata": {},
   "source": [
    "### Prepare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "5240927d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model(\n",
      "  (quant): QuantStub(\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (linear1): Linear(\n",
      "    in_features=128, out_features=1024, bias=True\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (relu1): ReLU()\n",
      "  (linear2): Linear(\n",
      "    in_features=1024, out_features=1024, bias=True\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (dequant): DeQuantStub()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "tq.prepare(model, inplace=True)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22499cc3",
   "metadata": {},
   "source": [
    "### Calibrate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "9e36b5d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model(\n",
      "  (quant): QuantStub(\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (linear1): Linear(\n",
      "    in_features=128, out_features=1024, bias=True\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (relu1): ReLU()\n",
      "  (linear2): Linear(\n",
      "    in_features=1024, out_features=1024, bias=True\n",
      "    (activation_post_process): HistogramObserver()\n",
      "  )\n",
      "  (dequant): DeQuantStub()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model(torch.randn(128, 128))\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47c5a2b4",
   "metadata": {},
   "source": [
    "### Convert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "e2662d30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model(\n",
      "  (quant): Quantize(scale=tensor([0.0585]), zero_point=tensor([65]), dtype=torch.quint8)\n",
      "  (linear1): QuantizedLinear(in_features=128, out_features=1024, scale=0.02753080427646637, zero_point=67, qscheme=torch.per_channel_affine)\n",
      "  (relu1): ReLU()\n",
      "  (linear2): QuantizedLinear(in_features=1024, out_features=1024, scale=0.007649691309779882, zero_point=64, qscheme=torch.per_channel_affine)\n",
      "  (dequant): DeQuantize()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "tq.convert(model, inplace=True)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "a3dae262",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.5018)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Should be around 50%\n",
    "(model.linear1.weight() == 0).float().mean()"
   ]
  },
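  {
   "cell_type": "markdown",
   "id": "sparse-quantized-forward-md",
   "metadata": {},
   "source": [
    "Note that masking only zeroes individual weight values; the converted modules are still dense quantized `Linear` layers, so inference should run exactly as in the non-sparse case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sparse-quantized-forward",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The masked-then-quantized model still runs as an ordinary quantized model\n",
    "out = model(torch.randn(8, 128))\n",
    "out.shape, out.dtype"
   ]
  },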
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dabece3e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}