Skip to content

Instantly share code, notes, and snippets.

@z-a-f
Last active June 11, 2021 06:27
Show Gist options
  • Save z-a-f/1d06ae8d5a509d3c9c1596dcb924afe0 to your computer and use it in GitHub Desktop.
Save z-a-f/1d06ae8d5a509d3c9c1596dcb924afe0 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "de245c82",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e6b14ae1",
"metadata": {},
"outputs": [],
"source": [
"backend = 'fbgemm'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cc61857a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.quantization as tq\n",
"\n",
"class Model(nn.Module):\n",
"    \"\"\"Two-layer MLP bracketed by quantization stubs.\n",
"\n",
"    Data flow: QuantStub, Linear(128, 1024), ReLU, Linear(1024, 1024),\n",
"    DeQuantStub. The stubs mark where eager-mode quantization should\n",
"    enter and leave the quantized domain.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.quant = tq.QuantStub()\n",
"        self.linear1 = nn.Linear(128, 1024)\n",
"        self.relu1 = nn.ReLU(inplace=False)\n",
"        self.linear2 = nn.Linear(1024, 1024)\n",
"        self.dequant = tq.DeQuantStub()\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply quant, linear1, relu1, linear2 and dequant, in that order.\n",
"\n",
"        Args:\n",
"            x: input tensor whose last dimension is 128 (in_features of linear1).\n",
"\n",
"        Returns:\n",
"            Tensor whose last dimension is 1024 (out_features of linear2).\n",
"        \"\"\"\n",
"        x = self.quant(x)\n",
"        x = self.linear1(x)\n",
"        x = self.relu1(x)\n",
"        x = self.linear2(x)\n",
"        x = self.dequant(x)\n",
"        return x"
]
},
{
"cell_type": "markdown",
"id": "40ebc075",
"metadata": {},
"source": [
"## Static quantization conversion"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "27e8bc6a",
"metadata": {},
"outputs": [],
"source": [
"import torch.ao.nn as ao_nn\n",
"\n",
"# Fresh float model, switched to inference mode before quantization\n",
"model = Model().eval()\n",
"# Attach the default quantization config for the chosen backend\n",
"model.qconfig = tq.get_default_qconfig(backend)"
]
},
{
"cell_type": "markdown",
"id": "f2301413",
"metadata": {},
"source": [
"### Prepare"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "43eb90f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model(\n",
" (quant): QuantStub(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (linear1): Linear(\n",
" in_features=128, out_features=1024, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (relu1): ReLU()\n",
" (linear2): Linear(\n",
" in_features=1024, out_features=1024, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (dequant): DeQuantStub()\n",
")\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/zafar/Git/pytorch-dev/pytorch/torch/quantization/observer.py:134: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.\n",
" warnings.warn(\n"
]
}
],
"source": [
"# Observers are inserted in place; prepare() hands back the same module\n",
"model = tq.prepare(model, inplace=True)\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "a019a6db",
"metadata": {},
"source": [
"### Calibrate"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "da434429",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model(\n",
" (quant): QuantStub(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (linear1): Linear(\n",
" in_features=128, out_features=1024, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (relu1): ReLU()\n",
" (linear2): Linear(\n",
" in_features=1024, out_features=1024, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (dequant): DeQuantStub()\n",
")\n"
]
}
],
"source": [
"# Calibration only needs a forward pass so the observers can record ranges;\n",
"# no_grad() avoids building an autograd graph for the float weights.\n",
"with torch.no_grad():\n",
"    model(torch.randn(128, 128))\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "38a9e24f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Linear(\n",
" in_features=128, out_features=1024, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
")"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.linear1"
]
},
{
"cell_type": "markdown",
"id": "e1bccb1f",
"metadata": {},
"source": [
"### Convert"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "52db42c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model(\n",
" (quant): Quantize(scale=tensor([0.0579]), zero_point=tensor([64]), dtype=torch.quint8)\n",
" (linear1): QuantizedLinear(in_features=128, out_features=1024, scale=0.040794678032398224, zero_point=64, qscheme=torch.per_channel_affine)\n",
" (relu1): ReLU()\n",
" (linear2): QuantizedLinear(in_features=1024, out_features=1024, scale=0.015324125066399574, zero_point=63, qscheme=torch.per_channel_affine)\n",
" (dequant): DeQuantize()\n",
")\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/zafar/Git/pytorch-dev/pytorch/torch/_tensor.py:557: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.\n",
"To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /home/zafar/Git/pytorch-dev/pytorch/aten/src/ATen/native/BinaryOps.cpp:461.)\n",
" return torch.floor_divide(self, other)\n"
]
}
],
"source": [
"# Swap the observed float modules for their quantized counterparts in place\n",
"model = tq.convert(model, inplace=True)\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "1690219b",
"metadata": {},
"source": [
"## Static sparse conversion"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "6eff751f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Model(\n",
" (quant): QuantStub()\n",
" (linear1): Linear(in_features=128, out_features=1024, bias=True)\n",
" (relu1): ReLU()\n",
" (linear2): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dequant): DeQuantStub()\n",
")"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch.ao.nn as ao_nn\n",
"\n",
"# No qconfig is attached in this section, so the quant/dequant stubs stay unconverted and can be ignored\n",
"model = Model()\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "1dda117b",
"metadata": {},
"outputs": [],
"source": [
"# We don't need prepare or convert, we just need to compute the mask\n",
"\n",
"for layer in (model.linear1, model.linear2):\n",
"    # Step 1. Create a mask (all ones, i.e. keep every weight)\n",
"    layer.register_buffer('mask', torch.ones(layer.weight.shape))\n",
"\n",
"    # Step 2. Compute some mask: random 0/1 entries, cast back to the buffer's\n",
"    # float dtype so the registered buffer does not silently become int64\n",
"    layer.mask = torch.randint(0, 2, layer.mask.shape).to(layer.mask.dtype)\n",
"\n",
"    # Step 3. Replace the weight (in place, without autograd tracking)\n",
"    with torch.no_grad():\n",
"        layer.weight.mul_(layer.mask)"
]
},
{
"cell_type": "markdown",
"id": "deb28ff8",
"metadata": {},
"source": [
"## Quantized Sparse conversion"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "4ecefb4b",
"metadata": {},
"outputs": [],
"source": [
"import torch.ao.nn as ao_nn\n",
"\n",
"# Fresh float model for the combined flow: the weights are masked first,\n",
"# then the qconfig lets prepare/convert quantize the stubs and linear layers\n",
"model = Model().eval()\n",
"model.qconfig = tq.get_default_qconfig(backend)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "1f2972be",
"metadata": {},
"outputs": [],
"source": [
"for layer in (model.linear1, model.linear2):\n",
"    # Step 1. Create a mask (all ones, i.e. keep every weight)\n",
"    layer.register_buffer('mask', torch.ones(layer.weight.shape))\n",
"\n",
"    # Step 2. Compute some mask: random 0/1 entries, cast back to the buffer's\n",
"    # float dtype so the registered buffer does not silently become int64\n",
"    layer.mask = torch.randint(0, 2, layer.mask.shape).to(layer.mask.dtype)\n",
"\n",
"    # Step 3. Replace the weight (in place, without autograd tracking)\n",
"    with torch.no_grad():\n",
"        layer.weight.mul_(layer.mask)"
]
},
{
"cell_type": "markdown",
"id": "3652f6d2",
"metadata": {},
"source": [
"### Prepare"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "5240927d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model(\n",
" (quant): QuantStub(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (linear1): Linear(\n",
" in_features=128, out_features=1024, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (relu1): ReLU()\n",
" (linear2): Linear(\n",
" in_features=1024, out_features=1024, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (dequant): DeQuantStub()\n",
")\n"
]
}
],
"source": [
"# Observers are inserted in place; prepare() hands back the same module\n",
"model = tq.prepare(model, inplace=True)\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "22499cc3",
"metadata": {},
"source": [
"### Calibrate"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "9e36b5d2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model(\n",
" (quant): QuantStub(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (linear1): Linear(\n",
" in_features=128, out_features=1024, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (relu1): ReLU()\n",
" (linear2): Linear(\n",
" in_features=1024, out_features=1024, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (dequant): DeQuantStub()\n",
")\n"
]
}
],
"source": [
"# Calibration only needs a forward pass so the observers can record ranges;\n",
"# no_grad() avoids building an autograd graph for the float weights.\n",
"with torch.no_grad():\n",
"    model(torch.randn(128, 128))\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "47c5a2b4",
"metadata": {},
"source": [
"### Convert"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "e2662d30",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model(\n",
" (quant): Quantize(scale=tensor([0.0585]), zero_point=tensor([65]), dtype=torch.quint8)\n",
" (linear1): QuantizedLinear(in_features=128, out_features=1024, scale=0.02753080427646637, zero_point=67, qscheme=torch.per_channel_affine)\n",
" (relu1): ReLU()\n",
" (linear2): QuantizedLinear(in_features=1024, out_features=1024, scale=0.007649691309779882, zero_point=64, qscheme=torch.per_channel_affine)\n",
" (dequant): DeQuantize()\n",
")\n"
]
}
],
"source": [
"# Swap the observed float modules for their quantized counterparts in place\n",
"model = tq.convert(model, inplace=True)\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "a3dae262",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.5018)"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fraction of exactly-zero weights in the quantized linear1; should be around 50% given the random 0/1 mask\n",
"(model.linear1.weight() == 0).float().mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dabece3e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment