@XinDongol · Created November 6, 2018 00:33
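# BN folding: a Conv2d followed by BatchNorm2d collapses into a single conv,
# since BN(W*x + b) = W'*x + b' with
#     W' = gamma * W / sqrt(var + eps)
#     b' = gamma * (b - mean) / sqrt(var + eps) + beta
# (gamma = 1, beta = 0 for non-affine BN). During training the module below
# folds with the current batch statistics (or the running ones when
# use_running=True), so the quantizer in fold_weight operates on the same
# weights the fused inference-time conv will actually use.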
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class FoldedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=1, affine=True, bias=True, use_running=False):
        super(FoldedConv2d, self).__init__()
        self.use_running = use_running
        self.bn = nn.BatchNorm2d(out_channels, affine=affine)
        self._weight = nn.Parameter(torch.Tensor(
            out_channels, in_channels, kernel_size, kernel_size))
        # He initialization.
        n = kernel_size * kernel_size * out_channels
        self._weight.data.normal_(0, math.sqrt(2. / n))
        if bias:
            self._bias = nn.Parameter(torch.Tensor(out_channels))
            self._bias.data.normal_(0, 1)
        else:
            # The original always allocated _bias; honor the bias flag instead.
            self.register_parameter('_bias', None)
        self.stride = stride
        self.padding = padding
        self.out_channels = out_channels
    def forward(self, x):
        if not self.training:
            # Inference: fold with the BN running statistics (properties below).
            return F.conv2d(x, self.weight, self.bias, self.stride, self.padding)
        if self.use_running:
            weight = self.fold_weight()
            bias = self.fold_bias()
        else:
            # Fold with the batch statistics of the pre-BN conv output.
            real_h = F.conv2d(x, self._weight, self._bias, self.stride, self.padding)
            C = real_h.size(1)
            self.real_output = self.bn(real_h)  # also updates BN running stats
            mu = real_h.transpose(0, 1).contiguous().view(C, -1).mean(1)
            # unbiased=False matches the variance BN uses to normalize in training.
            var = real_h.transpose(0, 1).contiguous().view(C, -1).var(1, unbiased=False)
            weight = self.fold_weight(var)
            bias = self.fold_bias(mu, var)
        return F.conv2d(x, weight, bias, self.stride, self.padding)
    def fold_weight(self, var=None):
        if var is None:
            var = self.bn.running_var
        # w' = gamma * w / sqrt(var + eps); gamma only exists for affine BN.
        scale = torch.rsqrt(var + self.bn.eps)
        if self.bn.affine:
            scale = scale * self.bn.weight
        folded_weight = self._weight * scale.view(-1, 1, 1, 1)
        # log_quantize is an external custom autograd Function not included in
        # the gist; it appears to quantize the folded weights to powers of two.
        folded_weight = log_quantize.apply(folded_weight, 2, -6, 1)
        return folded_weight
    def fold_bias(self, mean=None, var=None):
        if mean is None:
            mean = self.bn.running_mean
        if var is None:
            var = self.bn.running_var
        # b' = gamma * (b - mean) / sqrt(var + eps) + beta; the conv bias must
        # be scaled along with the mean, not subtracted from unscaled.
        scale = torch.rsqrt(var + self.bn.eps)
        if self.bn.affine:
            scale = scale * self.bn.weight
        b = self._bias if self._bias is not None else torch.zeros_like(mean)
        folded_bias = (b - mean) * scale
        if self.bn.affine:
            folded_bias = folded_bias + self.bn.bias
        return folded_bias
    @property
    def weight(self):
        return self.fold_weight()

    @property
    def bias(self):
        return self.fold_bias()
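
# Minimal smoke test. log_quantize is not part of the gist, so this
# hypothetical pass-through stub stands in for the real power-of-two
# quantizer; swap in the actual Function when available.
class log_quantize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, base, min_exp, max_exp):
        return x  # identity placeholder; the real Function would quantize x

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the gradient to x, none to the int arguments.
        return grad_output, None, None, None


if __name__ == '__main__':
    layer = FoldedConv2d(in_channels=3, out_channels=8, kernel_size=3)
    x = torch.randn(2, 3, 16, 16)
    layer.train()
    y_train = layer(x)  # folds with batch statistics
    layer.eval()
    y_eval = layer(x)   # folds with running statistics
    print(y_train.shape, y_eval.shape)  # torch.Size([2, 8, 16, 16]) twice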