from collections import OrderedDict
import math

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd.function import InplaceFunction, Function
def _mean(p, dim):
    """Computes the mean over all dimensions except dim."""
    if dim is None:
        return p.mean()
    elif dim == 0:
        output_size = (p.size(0),) + (1,) * (p.dim() - 1)
        return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
    elif dim == p.dim() - 1:
        output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
        return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
    else:
        return _mean(p.transpose(0, dim), 0).transpose(0, dim)
class UniformQuantize(InplaceFunction):

    @classmethod
    def forward(cls, ctx, input, num_bits=8, min_value=None, max_value=None,
                stochastic=False, inplace=False, enforce_true_zero=False,
                num_chunks=None, out_half=False):

        num_chunks = input.shape[0] if num_chunks is None else num_chunks
        if min_value is None or max_value is None:
            B = input.shape[0]
            y = input.view(B // num_chunks, -1)
        if min_value is None:
            min_value = y.min(-1)[0].mean(-1)  # C
            # min_value = float(input.view(input.size(0), -1).min(-1)[0].mean())
        if max_value is None:
            # max_value = float(input.view(input.size(0), -1).max(-1)[0].mean())
            max_value = y.max(-1)[0].mean(-1)  # C

        ctx.inplace = inplace
        ctx.num_bits = num_bits
        ctx.min_value = min_value
        ctx.max_value = max_value
        ctx.stochastic = stochastic

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        qmin = 0.
        qmax = 2. ** num_bits - 1.
        scale = (max_value - min_value) / (qmax - qmin)
        scale = max(scale, 1e-8)

        if enforce_true_zero:
            initial_zero_point = qmin - min_value / scale
            zero_point = 0.
            # make zero exactly representable
            if initial_zero_point < qmin:
                zero_point = qmin
            elif initial_zero_point > qmax:
                zero_point = qmax
            else:
                zero_point = initial_zero_point
            zero_point = int(zero_point)
            output.div_(scale).add_(zero_point)
        else:
            output.add_(-min_value).div_(scale).add_(qmin)

        if ctx.stochastic:
            noise = output.new(output.shape).uniform_(-0.5, 0.5)
            output.add_(noise)
        output.clamp_(qmin, qmax).round_()  # quantize

        if enforce_true_zero:
            output.add_(-zero_point).mul_(scale)  # dequantize
        else:
            output.add_(-qmin).mul_(scale).add_(min_value)  # dequantize
        if out_half and num_bits <= 16:
            output = output.half()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: pass the gradient through unchanged,
        # with one None per non-tensor argument of forward
        grad_input = grad_output
        return grad_input, None, None, None, None, None, None, None, None
class UniformQuantizeGrad(InplaceFunction):

    @classmethod
    def forward(cls, ctx, input, num_bits=8, min_value=None, max_value=None,
                stochastic=True, inplace=False):
        ctx.inplace = inplace
        ctx.num_bits = num_bits
        ctx.min_value = min_value
        ctx.max_value = max_value
        ctx.stochastic = stochastic
        return input

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.min_value is None:
            min_value = float(grad_output.min())
            # min_value = float(grad_output.view(
            #     grad_output.size(0), -1).min(-1)[0].mean())
        else:
            min_value = ctx.min_value
        if ctx.max_value is None:
            max_value = float(grad_output.max())
            # max_value = float(grad_output.view(
            #     grad_output.size(0), -1).max(-1)[0].mean())
        else:
            max_value = ctx.max_value
        grad_input = UniformQuantize.apply(grad_output, ctx.num_bits,
                                           min_value, max_value,
                                           ctx.stochastic, ctx.inplace)
        return grad_input, None, None, None, None, None
def quantize(x, num_bits=8, min_value=None, max_value=None, num_chunks=None,
             stochastic=False, inplace=False):
    # arguments must be passed to apply() in the order of UniformQuantize.forward
    # (stochastic/inplace come before num_chunks)
    return UniformQuantize.apply(x, num_bits, min_value, max_value,
                                 stochastic, inplace, False, num_chunks, False)


def quantize_grad(x, num_bits=8, min_value=None, max_value=None,
                  stochastic=True, inplace=False):
    return UniformQuantizeGrad.apply(x, num_bits, min_value, max_value,
                                     stochastic, inplace)
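
# --- Illustrative usage (not part of the original gist; the _demo_* helper is
# --- added here only as a sketch). `quantize` fake-quantizes the forward
# --- tensor with a straight-through gradient; `quantize_grad` is an identity
# --- in the forward pass and quantizes the gradient in the backward pass.
def _demo_quantize():
    x = torch.randn(4, 8, requires_grad=True)
    xq = quantize(x, num_bits=4,
                  min_value=float(x.min()), max_value=float(x.max()))
    y = quantize_grad(xq, num_bits=8)  # identity forward, quantized gradient backward
    y.sum().backward()
    print(xq.unique().numel(), 'distinct values after 4-bit quantization')
    print(x.grad.shape)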
class QuantMeasure(nn.Module):
    """Tracks a running min/max of its input and fake-quantizes the input to num_bits."""

    def __init__(self, num_bits=8, momentum=0.1,
                 init_running_min=-2.0, init_running_max=2.0):
        super(QuantMeasure, self).__init__()
        # keep the running range in buffers so it is saved and moved with the module
        self.register_buffer('running_min', torch.full((1,), float(init_running_min)))
        self.register_buffer('running_max', torch.full((1,), float(init_running_max)))
        self.momentum = momentum
        self.num_bits = num_bits

    def forward(self, input):
        if self.training:
            min_value = input.detach().view(
                input.size(0), -1).min(-1)[0].mean()
            max_value = input.detach().view(
                input.size(0), -1).max(-1)[0].mean()
            self.running_min.mul_(self.momentum).add_(
                min_value * (1 - self.momentum))
            self.running_max.mul_(self.momentum).add_(
                max_value * (1 - self.momentum))
        else:
            min_value = self.running_min
            max_value = self.running_max
        return quantize(input, self.num_bits, min_value=float(min_value),
                        max_value=float(max_value), num_chunks=16)
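
# --- Illustrative usage (not part of the original gist). In training mode
# --- QuantMeasure quantizes with the current batch range and updates its
# --- running min/max; in eval mode it reuses the stored estimates.
def _demo_quant_measure():
    qm = QuantMeasure(num_bits=4)
    qm.train()
    for _ in range(10):
        _ = qm(torch.randn(32, 16))   # updates running_min / running_max
    qm.eval()
    out = qm(torch.randn(32, 16))     # quantizes with the frozen running range
    print(float(qm.running_min), float(qm.running_max), out.shape)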
class QConv2d(nn.Conv2d):
    """nn.Conv2d with quantized input activations, weights and (optionally) gradients."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 num_bits=8, num_bits_weight=None, num_bits_grad=None,
                 biprecision=False):
        super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias)
        self.num_bits = num_bits
        self.num_bits_weight = num_bits_weight or num_bits
        self.num_bits_grad = num_bits_grad
        self.quantize_input = QuantMeasure(self.num_bits)
        self.biprecision = biprecision

    def forward(self, input):
        qinput = self.quantize_input(input)
        qweight = quantize(self.weight, num_bits=self.num_bits_weight,
                           min_value=float(self.weight.min()),
                           max_value=float(self.weight.max()))
        if self.bias is not None:
            qbias = quantize(self.bias, num_bits=self.num_bits_weight)
        else:
            qbias = None
        if not self.biprecision or self.num_bits_grad is None:
            output = F.conv2d(qinput, qweight, qbias, self.stride,
                              self.padding, self.dilation, self.groups)
            if self.num_bits_grad is not None:
                output = quantize_grad(output, num_bits=self.num_bits_grad)
        else:
            # conv2d_biprec is not included in this gist; the biprecision path
            # only works if you provide that function yourself
            output = conv2d_biprec(qinput, qweight, qbias, self.stride,
                                   self.padding, self.dilation, self.groups,
                                   num_bits_grad=self.num_bits_grad)
        return output
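
# --- Illustrative usage (not part of the original gist). QConv2d is meant as a
# --- drop-in replacement for nn.Conv2d: activations go through QuantMeasure,
# --- weights are quantized on every call, and with num_bits_grad set the
# --- gradient of the output is quantized as well.
def _demo_qconv2d():
    layer = QConv2d(3, 16, kernel_size=3, padding=1,
                    num_bits=8, num_bits_weight=4, num_bits_grad=8)
    x = torch.randn(2, 3, 32, 32, requires_grad=True)
    y = layer(x)
    y.mean().backward()   # gradients flow through the straight-through estimator
    print(y.shape, x.grad.shape)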
def compute_integral_part(input, overflow_rate):
    # number of integer bits needed so that only `overflow_rate` of the
    # values (by magnitude) would saturate
    abs_value = input.abs().view(-1)
    sorted_value = abs_value.sort(dim=0, descending=True)[0]
    split_idx = int(overflow_rate * len(sorted_value))
    v = sorted_value[split_idx]
    if isinstance(v, Variable):
        v = float(v.data)
    sf = math.ceil(math.log2(v + 1e-12))
    return sf


def linear_quantize(input, sf, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input) - 1
    delta = math.pow(2.0, -sf)
    bound = math.pow(2.0, bits - 1)
    min_val = - bound
    max_val = bound - 1
    rounded = torch.floor(input / delta + 0.5)
    clipped_value = torch.clamp(rounded, min_val, max_val) * delta
    return clipped_value
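
# --- Illustrative usage (not part of the original gist). A sketch of how the
# --- two functions are typically combined: compute_integral_part picks the
# --- number of integer bits that keeps roughly `overflow_rate` of the values
# --- from saturating, the remaining bits become fractional bits for
# --- linear_quantize.
def _demo_linear_quantize():
    w = torch.randn(64, 32) * 3.0
    bits = 8
    int_bits = compute_integral_part(w, overflow_rate=0.001)  # integer bits
    sf = bits - 1 - int_bits                                  # fractional bits
    wq = linear_quantize(w, sf, bits)
    print('max abs error:', float((w - wq).abs().max()))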
def log_minmax_quantize(input, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input)
    s = torch.sign(input)
    input0 = torch.log(torch.abs(input) + 1e-20)
    v = min_max_quantize(input0, bits)
    v = torch.exp(v) * s
    return v


def log_linear_quantize(input, sf, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input)
    s = torch.sign(input)
    input0 = torch.log(torch.abs(input) + 1e-20)
    v = linear_quantize(input0, sf, bits)
    v = torch.exp(v) * s
    return v


def min_max_quantize(input, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input) - 1
    min_val, max_val = input.min(), input.max()
    if isinstance(min_val, Variable):
        max_val = float(max_val.data)
        min_val = float(min_val.data)
    input_rescale = (input - min_val) / (max_val - min_val)
    n = math.pow(2.0, bits) - 1
    v = torch.floor(input_rescale * n + 0.5) / n
    v = v * (max_val - min_val) + min_val
    return v
def tanh_quantize(input, bits):
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input)
    input = torch.tanh(input)  # [-1, 1]
    input_rescale = (input + 1.0) / 2  # [0, 1]
    n = math.pow(2.0, bits) - 1
    v = torch.floor(input_rescale * n + 0.5) / n
    v = 2 * v - 1  # [-1, 1]
    v = 0.5 * torch.log((1 + v) / (1 - v))  # arctanh
    return v
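
# --- Illustrative comparison (not part of the original gist). A quick sketch
# --- contrasting the simple quantizers above on the same tensor:
# --- min_max_quantize works on the raw range, log_minmax_quantize on
# --- log-magnitudes, tanh_quantize squashes through tanh before rounding.
def _demo_simple_quantizers():
    # kept small so tanh_quantize does not saturate to +/-1 (arctanh is infinite there)
    x = torch.randn(1000) * 0.3
    for fn in (min_max_quantize, log_minmax_quantize, tanh_quantize):
        xq = fn(x, bits=4)
        print(fn.__name__, 'mse:', float((x - xq).pow(2).mean()))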
class My_Layer(nn.Module):
    """Wraps an existing module and quantizes its input before calling it."""

    def __init__(self, m, num_bits):
        super(My_Layer, self).__init__()
        self.quantize_input = QuantMeasure(num_bits)
        self.m = m

    def forward(self, x):
        x = self.quantize_input(x)
        x = self.m(x)
        return x
def replace(model):
    """Recursively replace every target layer with a quantized My_Layer wrapper (in place)."""
    if len(model._modules) == 0:
        return model
    for i, m in model._modules.items():
        if if_target_layer(m):
            model._modules[i] = My_Layer(m, 8)
        else:
            model._modules[i] = replace(m)
    return model


def if_target_layer(m):
    '''
    define your own identification function
    '''
    return isinstance(m, nn.Linear)
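
# --- Illustrative usage (not part of the original gist). Every nn.Linear in a
# --- toy model gets wrapped in My_Layer, so its input is quantized before the
# --- original layer runs; other modules are left untouched.
def _demo_replace():
    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 5),
    )
    replace(model)   # mutates the model in place
    print(model)
    out = model(torch.randn(4, 10))
    print(out.shape)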
def duplicate_model_with_quant(model, **kwargs):
    """Assumes the original model contains at least one nn.Sequential;
    adapt this function to implement a different quantization scheme.
    """
    assert kwargs['type'] in ['linear', 'minmax', 'log', 'tanh']
    if isinstance(model, nn.Sequential):
        # if it is a Sequential, build a new Sequential on top of it
        l = OrderedDict()  # the new sequence
        for k, v in model._modules.items():  # inside the sequence
            if not isinstance(v, nn.Sequential):
                l[k] = v  # add the original layer to the new sequence
                if isinstance(v, nn.ReLU):
                    # placeholder quantization layer; swap in the scheme you need
                    quant_layer = nn.Tanh()
                    # add the new layer right after the activation
                    l['{}_{}_quant'.format(k, kwargs['type'])] = quant_layer
            else:
                l[k] = duplicate_model_with_quant(v, **kwargs)
        m = nn.Sequential(l)
        return m
    else:
        # if not a Sequential, go deeper and search
        for k, v in model._modules.items():
            model._modules[k] = duplicate_model_with_quant(v, **kwargs)
        return model
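
# --- Illustrative usage (not part of the original gist). For a model whose
# --- layers live inside nn.Sequential containers, a new Sequential is built
# --- with an extra layer inserted after every ReLU (a Tanh placeholder in
# --- this gist; replace it with the quantizer you actually want).
def _demo_duplicate():
    model = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(10, 20)),
        ('relu1', nn.ReLU()),
        ('fc2', nn.Linear(20, 5)),
    ]))
    qmodel = duplicate_model_with_quant(model, type='minmax')
    print(qmodel)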
class BinOp():
    def __init__(self, model):
        '''
        attributes:
            self.saved_params: full-precision copies of the weights
            self.target_modules: weight Parameters that get quantized
            self.num_of_params: len(self.target_modules)
            self.quant_layer_list: layers that need quantization
        '''
        self.model = model
        self.saved_params = []
        self.target_modules = []
        self.get_quant_layer()
        self.num_of_params = len(self.quant_layer_list)
        for m in self.quant_layer_list:
            tmp = m.weight.data.clone()
            self.saved_params.append(tmp)
            self.target_modules.append(m.weight)

    def get_quant_layer(self):
        '''
        output: self.quant_layer_list
        use this function to gather the layers that need to be quantized
        '''
        self.quant_layer_list = []
        for index, m in enumerate(self.model.modules()):  # scan all layers
            if isinstance(m, nn.Conv2d):  # choose which layers you want to quantize
                self.quant_layer_list.append(m)
                print('Quantizing the ', index, ' layer: ', m)
            else:
                print('Ignoring the ', index, ' layer: ', m)
        print('Total quantized layers: ', len(self.quant_layer_list))
        return self.quant_layer_list
    def binarization(self):
        self.meancenterConvParams()
        self.clampConvParams()
        self.save_params()
        self.binarizeConvParams()

    def meancenterConvParams(self):
        for index in range(self.num_of_params):
            negMean = self.target_modules[index].data.mean(1, keepdim=True).\
                mul(-1).expand_as(self.target_modules[index].data)
            self.target_modules[index].data = self.target_modules[index].data.add(negMean)

    def clampConvParams(self):
        for index in range(self.num_of_params):
            self.target_modules[index].data.clamp_(-1.0, 1.0)

    def save_params(self):
        for index in range(self.num_of_params):
            self.saved_params[index].copy_(self.target_modules[index].data)

    def binarizeConvParams(self):
        for index in range(self.num_of_params):
            n = self.target_modules[index].data[0].nelement()
            s = self.target_modules[index].data.size()
            # per-output-channel L1 scaling factor (alpha in XNOR-Net)
            m = self.target_modules[index].data.norm(1, 3, keepdim=True)\
                .sum(2, keepdim=True).sum(1, keepdim=True).div(n)
            self.target_modules[index].data.copy_(
                self.target_modules[index].data.sign().mul(m.expand(s)))

    def restore(self):
        for index in range(self.num_of_params):
            self.target_modules[index].data.copy_(self.saved_params[index])

    def updateBinaryGradWeight(self):
        for index in range(self.num_of_params):
            weight = self.target_modules[index].data
            n = weight[0].nelement()
            s = weight.size()
            m = weight.norm(1, 3, keepdim=True)\
                .sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(s).clone()
            m[weight.lt(-1.0)] = 0
            m[weight.gt(1.0)] = 0
            # m = m.add(1.0/n).mul(1.0-1.0/s[1]).mul(n)
            # self.target_modules[index].grad.data = \
            #     self.target_modules[index].grad.data.mul(m)
            m = m.mul(self.target_modules[index].grad.data)
            m_add = weight.sign().mul(self.target_modules[index].grad.data)
            m_add = m_add.sum(3, keepdim=True)\
                .sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(s)
            m_add = m_add.mul(weight.sign())
            self.target_modules[index].grad.data = m.add(m_add).mul(1.0 - 1.0 / s[1]).mul(n)
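
# --- Illustrative usage (not part of the original gist). A sketch of the
# --- XNOR-Net-style training loop BinOp is meant for: binarize the conv
# --- weights for the forward/backward pass, restore the full-precision
# --- copies, then rescale the gradients before the optimizer step.
def _demo_binop():
    model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 8, 3, padding=1),
    )
    bin_op = BinOp(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    x, target = torch.randn(2, 3, 16, 16), torch.randn(2, 8, 16, 16)
    for _ in range(3):
        bin_op.binarization()            # forward/backward with binary weights
        loss = F.mse_loss(model(x), target)
        optimizer.zero_grad()
        loss.backward()
        bin_op.restore()                 # put the full-precision weights back
        bin_op.updateBinaryGradWeight()  # rescale gradients as in XNOR-Net
        optimizer.step()
        print(float(loss))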