Created
February 12, 2019 20:42
-
-
Save ajbrock/d4a52ea75a7284ebafbccf87cf63414c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
  """Batch-normalize x over the (N, H, W) axes with freshly computed stats.

  Computes per-channel mean and variance of x, then normalizes via
  fused_bn (optionally applying gain/bias). If return_mean_var is True,
  also returns the squeezed per-channel mean and variance so a caller
  can track running statistics.
  """
  # E[x], kept with singleton dims so it broadcasts against x
  mu = torch.mean(x, [0, 2, 3], keepdim=True)
  # E[x^2]
  mu_sq = torch.mean(x * x, [0, 2, 3], keepdim=True)
  # Var[x] = E[x^2] - (E[x])^2
  sigma2 = mu_sq - mu * mu
  normed = fused_bn(x, mu, sigma2, gain, bias, eps)
  if return_mean_var:
    return normed, mu.squeeze(), sigma2.squeeze()
  return normed
# Apply scale and shift--if gain and bias are provided, fuse them here
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
  """Normalize x with the given stats, folding optional gain and bias
  into a single multiply-subtract: out = x * scale - (mean * scale - bias).
  """
  # 1 / sqrt(var + eps)
  inv_std = torch.rsqrt(var + eps)
  if gain is not None:
    # Fold the gain into the scale so normalization + gain is one multiply.
    inv_std = inv_std * gain
  # Offset that centers x (and carries the bias if present).
  offset = mean * inv_std
  if bias is not None:
    offset = offset - bias
  return x * inv_std - offset
#### This is just a module wrapper that does bookkeeping, the above two functions are what do the batchnorming
# My batchnorm, supports standing stats
class myBN(nn.Module):
  """Batchnorm module that tracks its own running statistics.

  Supports "standing statistics": when `accumulate_standing` is True,
  per-batch means/variances are summed into the stored buffers and a
  counter is incremented, so evaluation can average over many batches
  instead of using an exponential moving average.
  """
  def __init__(self, num_channels, eps=1e-5, momentum=0.1):
    super(myBN, self).__init__()
    # epsilon to avoid dividing by 0
    self.eps = eps
    # momentum for the running-average update of the stored stats
    # (fixed: the original assigned this attribute twice)
    self.momentum = momentum
    # Register buffers so the stats follow .to()/.cuda() and land in
    # state_dict, but are not trainable parameters.
    self.register_buffer('stored_mean', torch.zeros(num_channels))
    self.register_buffer('stored_var', torch.ones(num_channels))
    self.register_buffer('accumulation_counter', torch.zeros(1))
    # Accumulate running means and vars (standing stats) instead of EMA
    self.accumulate_standing = False
  # reset standing stats so a fresh accumulation pass can begin
  def reset_stats(self):
    self.stored_mean[:] = 0
    self.stored_var[:] = 0
    self.accumulation_counter[:] = 0
  def forward(self, x, gain, bias):
    if self.training:
      # Normalize with batch stats and get them back for bookkeeping.
      out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
      if self.accumulate_standing:
        # Sum (not average) the stats; eval divides by the counter.
        self.stored_mean[:] = self.stored_mean + mean.data
        self.stored_var[:] = self.stored_var + var.data
        self.accumulation_counter += 1.0
      # If not accumulating standing stats, take running averages.
      # (.data used here too, consistent with the branch above, so the
      # buffer update never drags in the autograd graph.)
      else:
        self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean.data * self.momentum
        self.stored_var[:] = self.stored_var * (1 - self.momentum) + var.data * self.momentum
      return out
    # If not in training mode, use stored stats and don't update them
    else:
      mean = self.stored_mean.view(1, -1, 1, 1)
      var = self.stored_var.view(1, -1, 1, 1)
      # If using standing stats, divide them by the accumulation counter
      if self.accumulate_standing:
        mean = mean / self.accumulation_counter
        var = var / self.accumulation_counter
      return fused_bn(x, mean, var, gain, bias, self.eps)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment