import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class FoldedConv2d(nn.Module):
    """Conv2d with its following BatchNorm folded into the conv weight and bias.

    During training (with use_running=False) the batch statistics of the
    un-folded convolution output are used for folding; at inference time the
    BatchNorm running statistics are used instead.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=1, affine=True, bias=True, use_running=False):
        super(FoldedConv2d, self).__init__()
        self.use_running = use_running
        self.bn = torch.nn.BatchNorm2d(out_channels, affine=affine)
        self._weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
        # He-style initialisation of the conv weight.
        n = kernel_size * kernel_size * out_channels
        self._weight.data.normal_(0, math.sqrt(2. / n))
        # Note: the `bias` flag is not honoured; a conv bias is always created.
        self._bias = nn.Parameter(torch.Tensor(out_channels))
        self._bias.data.normal_(0, 1)
        self.stride = stride
        self.padding = padding
        self.out_channels = out_channels

    def forward(self, x):
        if not self.training:
            # Inference: fold with the BatchNorm running statistics.
            return F.conv2d(x, self.weight, self.bias, self.stride, self.padding)
        if self.use_running:
            weight = self.fold_weight()
            bias = self.fold_bias()
        else:
            # Run the un-folded convolution once to obtain the per-channel batch
            # statistics (and to keep the BatchNorm running estimates updated).
            real_h = F.conv2d(x, self._weight, self._bias, self.stride, self.padding)
            C = real_h.size(1)
            self.real_output = self.bn(real_h)
            mu = real_h.transpose(0, 1).contiguous().view(C, -1).mean(1)
            var = real_h.transpose(0, 1).contiguous().view(C, -1).var(1)
            weight = self.fold_weight(var)
            bias = self.fold_bias(mu, var)
        return F.conv2d(x, weight, bias, self.stride, self.padding)

    def fold_weight(self, var=None):
        if var is None:
            var = self.bn.running_var
        # Fold the BatchNorm normalisation into the conv weight:
        #   w_folded = w / sqrt(var + eps)
        # The affine scale is intentionally not folded here; the commented
        # lines below show the full affine version.
        # w_conv = self._weight.clone().view(self.out_channels, -1)
        # w_bn = torch.diag(self.bn.weight.div(torch.sqrt(self.bn.eps + var)))
        # folded_weight = torch.mm(w_bn, w_conv).view(self._weight.size())
        w_conv = self._weight.clone()
        folded_weight = w_conv / torch.sqrt(var + self.bn.eps).view(-1, 1, 1, 1)
        # `log_quantize` is a custom autograd Function (logarithmic weight
        # quantisation) that is assumed to be defined elsewhere.
        folded_weight = log_quantize.apply(folded_weight, 2, -6, 1)
        return folded_weight
    def fold_bias(self, mean=None, var=None):
        if mean is None:
            mean = self.bn.running_mean
        if var is None:
            var = self.bn.running_var
        # Fold the BatchNorm shift into the conv bias:
        #   b_folded = (b - mean) / sqrt(var + eps)
        # The commented lines below show the full affine version.
        # b_bn = torch.sqrt(var + self.bn.eps)
        # b_bn = self.bn.bias - self.bn.weight.mul(mean).div(b_bn)
        std = torch.sqrt(var + self.bn.eps)
        if self._bias is not None:
            return (self._bias.clone() - mean) / std
        else:
            return -mean / std
    @property
    def weight(self):
        return self.fold_weight()

    @property
    def bias(self):
        return self.fold_bias()
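A minimal usage sketch follows. The gist does not include the `log_quantize` Function it references, so a pass-through stand-in is defined here purely so the example runs; the layer sizes and input shape are illustrative assumptions, not part of the original code.

```python
import torch


class log_quantize(torch.autograd.Function):
    """Placeholder stub for the missing quantizer; replace with the real
    logarithmic-quantisation Function used by the gist."""

    @staticmethod
    def forward(ctx, x, base, min_exp, max_exp):
        # Identity pass-through; the real version would quantize `x`
        # to powers of `base` within the given exponent range.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None, None, None


conv = FoldedConv2d(in_channels=3, out_channels=16, kernel_size=3,
                    stride=1, padding=1)
x = torch.randn(8, 3, 32, 32)

conv.train()
y_train = conv(x)   # folds with this mini-batch's statistics, updates running stats
conv.eval()
y_eval = conv(x)    # folds with the BatchNorm running statistics
print(y_train.shape, y_eval.shape)  # torch.Size([8, 16, 32, 32]) for both
```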