Created
August 3, 2020 19:35
-
-
Save Smthri/7e862838093532f1aa38857d345e2816 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import argparse | |
import numpy as np | |
import torch | |
import torch.quantization as tq | |
from pathlib import Path | |
from torch import nn | |
from _collections_abc import Iterable | |
from collections import OrderedDict | |
def get_test_model(input_shape):
    """Build a random bias-free 3x3 Conv2d and return its quantized form.

    The float layer maps ``input_shape[0]`` input channels to 3 output
    channels with a 3x3 kernel and no bias.  Its weights are re-drawn from a
    normal distribution with mean 0 and standard deviation 20 before the
    layer is handed to ``quantize``.

    Args:
        input_shape: per-sample input shape; element 0 is the channel count
            used to size the layer, and the whole tuple is forwarded to
            ``quantize``.

    Returns:
        The model returned by ``quantize``, switched to eval mode.
    """
    conv = nn.Conv2d(input_shape[0], 3, 3, bias=False)
    for param in conv.parameters():
        # mean 0, std 20 (overrides the default Conv2d initialisation)
        torch.nn.init.normal_(param, 0, 20)
    # nn.Module.eval() returns the module itself, so it can be returned directly.
    return quantize(conv, input_shape).eval()
def quantize(model, input_shape, num_batches=1000, batch_size=2):
    """Statically quantize ``model`` with per-channel int8 weights.

    The model is configured with a MinMax activation observer and a
    per-channel MinMax (``per_channel_affine``, ``qint8``) weight observer,
    both with ``reduce_range=True``.  It is then wrapped in a
    ``QuantWrapper`` so inputs are quantized and outputs dequantized.
    Calibration runs on standard-normal random batches, and the wrapper is
    converted in place.

    Args:
        model: float module to quantize.  It receives a ``qconfig``
            attribute and is wrapped, prepared and converted in place.
        input_shape: per-sample input shape.  Each calibration batch has
            shape ``(batch_size, *input_shape)``.
        num_batches: number of calibration forward passes (default 1000).
        batch_size: samples per calibration batch (default 2).

    Returns:
        The converted ``QuantWrapper`` in eval mode; its ``module`` child is
        the quantized counterpart of ``model``.
    """
    # ``with_args`` is a classmethod, so no observer instance is needed.
    weight_observer = tq.PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8,
        qscheme=torch.per_channel_affine,
        reduce_range=True,
    )
    model.qconfig = tq.QConfig(
        activation=tq.MinMaxObserver.with_args(reduce_range=True),
        weight=weight_observer,
    )
    wrapped = tq.QuantWrapper(model)
    # Post-training static quantization expects the model in eval mode
    # during calibration (matters for dropout / batch-norm style layers).
    wrapped.eval()
    with torch.no_grad():
        tq.prepare(wrapped, inplace=True)
        for _ in range(num_batches):
            # Only the observers' side effect matters; the output is discarded.
            wrapped(torch.randn(batch_size, *input_shape))
        tq.convert(wrapped, inplace=True)
    return wrapped
# Build a quantized test model, run it step by step, and compare its output
# with a float32 reference convolution computed in NumPy from the
# dequantized input and weights.
input_shape = (5, 7, 7)
torch.manual_seed(42)
model = get_test_model(input_shape)
x = torch.randn(1, *input_shape)
print(x)

# step-by-step forward through the wrapper: quantize -> conv -> dequantize
q_inp = model.quant(x)
q_outp = model.module(q_inp)
f_outp = model.dequant(q_outp)
inp_scale = q_inp.q_scale()
inp_zero_point = q_inp.q_zero_point()
# Output quantization parameters (not used by the checks below).
outp_scale = q_outp.q_scale()
outp_zero_point = q_outp.q_zero_point()
q_inp = q_inp.int_repr().detach().numpy()[0]
q_outp = q_outp.int_repr().detach().numpy()[0]

# Check the step-by-step forward matches model(x) at EVERY element; .any()
# would report success if a single element happened to match.
print((abs(f_outp - model(x)) < 1e-9).all().item())

# Dequantize the per-channel kernel: (q - zero_point[c]) * scale[c].
q_weight = model.module.weight()
q_ker = q_weight.int_repr().detach().numpy().astype(np.float32)
ker_scales = q_weight.q_per_channel_scales().detach().numpy().astype(np.float32)
ker_zero_points = q_weight.q_per_channel_zero_points().detach().numpy().astype(np.float32)
out_channels, _, kh, kw = q_ker.shape
f_ker = (q_ker - ker_zero_points.reshape(-1, 1, 1, 1)) * ker_scales.reshape(-1, 1, 1, 1)
f_inp = (q_inp.astype(np.float32) - np.float32(inp_zero_point)) * np.float32(inp_scale)

# Reference sliding-window product (stride 1, no padding, no bias term).
out_h = input_shape[1] - kh + 1
out_w = input_shape[2] - kw + 1
res = np.zeros((out_channels, out_h, out_w), dtype=np.float32)
for c in range(out_channels):
    for i in range(out_h):
        for j in range(out_w):
            res[c, i, j] = (f_inp[:, i:i + kh, j:j + kw] * f_ker[c]).sum()

# Difference between the quantized layer's dequantized output and the float
# reference; presumably dominated by output requantization rounding.
print(f_outp.detach().numpy()[0] - res)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment