Skip to content

Instantly share code, notes, and snippets.

@Smthri
Created August 3, 2020 19:35
Show Gist options
  • Save Smthri/7e862838093532f1aa38857d345e2816 to your computer and use it in GitHub Desktop.
Save Smthri/7e862838093532f1aa38857d345e2816 to your computer and use it in GitHub Desktop.
import os
import sys
import argparse
import numpy as np
import torch
import torch.quantization as tq
from pathlib import Path
from torch import nn
from _collections_abc import Iterable
from collections import OrderedDict
def get_test_model(input_shape):
    """Build a randomly initialised, statically quantized 3x3 convolution.

    A bias-free ``nn.Conv2d`` mapping ``input_shape[0]`` input channels to 3
    output channels with a 3x3 kernel has its weight drawn from a normal
    distribution with mean 0 and standard deviation 20. It is then passed
    through ``quantize`` and returned in eval mode.

    Args:
        input_shape: per-sample input shape ``(channels, height, width)``.

    Returns:
        The quantized model produced by ``quantize``, switched to eval mode.
    """
    conv = nn.Conv2d(input_shape[0], 3, 3, bias=False)
    # bias=False, so the weight is the only parameter that needs initialising.
    nn.init.normal_(conv.weight, 0, 20)
    # Module.eval() returns the module itself, so it can be chained.
    return quantize(conv, input_shape).eval()
def quantize(model, input_shape, num_calibration_batches=1000, calibration_batch_size=2):
    """Apply eager-mode post-training static quantization to ``model``.

    The model is wrapped in ``tq.QuantWrapper`` (quant/dequant stubs around it).
    The QConfig uses a per-tensor ``MinMaxObserver`` for activations and a
    per-channel affine qint8 ``PerChannelMinMaxObserver`` for weights, both
    with ``reduce_range=True``. Observers are calibrated on random
    standard-normal batches, and the wrapper is then converted to its
    quantized counterpart in place.

    Args:
        model: float ``nn.Module`` to quantize. It is mutated: its ``qconfig``
            attribute is overwritten and ``tq.prepare`` attaches observers to it.
        input_shape: per-sample input shape (without the batch dimension).
        num_calibration_batches: number of random batches fed through the
            observers. Defaults to 1000.
        calibration_batch_size: samples per calibration batch. Defaults to 2.

    Returns:
        The converted ``tq.QuantWrapper``.

    Raises:
        ValueError: if ``num_calibration_batches`` or ``calibration_batch_size``
            is less than 1. Without calibration data the observers would
            produce meaningless quantization parameters.
    """
    if num_calibration_batches < 1:
        raise ValueError("num_calibration_batches must be at least 1")
    if calibration_batch_size < 1:
        raise ValueError("calibration_batch_size must be at least 1")

    with torch.no_grad():
        # ``with_args`` is a classmethod, so no throwaway observer instance is needed.
        model.qconfig = tq.QConfig(
            activation=tq.MinMaxObserver.with_args(reduce_range=True),
            weight=tq.PerChannelMinMaxObserver.with_args(
                dtype=torch.qint8,
                qscheme=torch.per_channel_affine,
                reduce_range=True,
            ),
        )
        wrapped = tq.QuantWrapper(model)
        # Static PTQ must calibrate in eval mode so that layers such as
        # BatchNorm/Dropout behave as they will at inference time.
        wrapped.eval()
        tq.prepare(wrapped, inplace=True)
        # Feed random data through the observers to collect min/max statistics.
        for _ in range(num_calibration_batches):
            wrapped(torch.randn(calibration_batch_size, *input_shape))
        tq.convert(wrapped, inplace=True)
    return wrapped
# Experiment: run the quantized model step by step and compare its output with a
# float32 reference convolution computed from the dequantized input and weights.
input_shape = (5, 7, 7)
torch.manual_seed(42)
model = get_test_model(input_shape)
x = torch.randn(1, *input_shape)
print(x)

# step-by-step forward through the QuantWrapper: quantize -> quantized conv -> dequantize
q_inp = model.quant(x)
q_outp = model.module(q_inp)
f_outp = model.dequant(q_outp)

inp_scale = q_inp.q_scale()
inp_zero_point = q_inp.q_zero_point()
# Integer representation of the single input sample (batch index 0).
q_inp = q_inp.int_repr().detach().numpy()[0]

# Check that the step-by-step forward matches a regular forward on EVERY element;
# ``.any()`` would already pass if just one element happened to match.
print((abs(f_outp - model(x)) < 1e-9).all().item())

# ``weight()`` unpacks the packed quantized weight, so fetch it only once.
q_weight = model.module.weight()
q_ker = q_weight.int_repr().detach().numpy().astype(np.float32)
ker_scales = q_weight.q_per_channel_scales().detach().numpy().astype(np.float32)
ker_zero_points = q_weight.q_per_channel_zero_points().detach().numpy().astype(np.float32)
out_channels, _, k_h, k_w = q_ker.shape
# Dequantize per output channel: scales/zero points broadcast over (in_ch, kH, kW).
f_ker = (q_ker - ker_zero_points.reshape(out_channels, 1, 1, 1)) * ker_scales.reshape(
    out_channels, 1, 1, 1
)
f_inp = (q_inp.astype(np.float32) - np.float32(inp_zero_point)) * np.float32(inp_scale)

# Reference "valid" cross-correlation. This assumes stride 1, no padding, no
# dilation and no bias, which holds for the nn.Conv2d(..., bias=False) built by
# get_test_model with default arguments.
out_h = input_shape[1] - k_h + 1
out_w = input_shape[2] - k_w + 1
res = np.zeros((out_channels, out_h, out_w), dtype=np.float32)
for c in range(out_channels):
    for i in range(out_h):
        for j in range(out_w):
            res[c, i, j] = (f_inp[:, i:i + k_h, j:j + k_w] * f_ker[c]).sum()

# f_outp has been rounded (and clamped) to the output quantization grid, so
# nonzero differences on the order of the output scale are expected here.
print(f_outp.detach().numpy()[0] - res)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment