# Source: GitHub Gist 291248803b87cc586edebb570f098d7e by @xmodar, created November 3, 2021.
"""Invertible BatchNorm"""
import torch
from torch import nn
class NonZero(nn.Module):
    """Parameterization to force the values to be nonzero

    Every element whose magnitude is not strictly greater than ``eps`` is
    replaced by a constant of magnitude ``eps``; all other elements pass
    through unchanged.

    Args:
        eps: Magnitude threshold, also used as the replacement magnitude.
        preserve_sign: When true, strictly negative elements are replaced by
            ``-eps`` instead of ``+eps`` (an exact zero is not ``< 0`` and
            therefore maps to ``+eps``).
    """

    def __init__(self, eps=1e-5, preserve_sign=True):
        super().__init__()
        self.eps, self.preserve_sign = eps, preserve_sign

    def extra_repr(self):
        """Expose the configuration in ``repr(module)``"""
        return f'eps={self.eps}, preserve_sign={self.preserve_sign}'

    def forward(self, inputs):
        """Perform the forward pass

        Args:
            inputs: Tensor whose small-magnitude elements should be replaced.

        Returns:
            A tensor with the dtype and device of ``inputs`` in which each
            element failing ``abs(x) > eps`` has been replaced by ``eps``
            (or ``-eps``, see ``preserve_sign``).
        """
        # Build the replacement with the dtype/device of the input so that
        # `where` does not have to promote or move anything.
        eps = torch.tensor(self.eps, dtype=inputs.dtype, device=inputs.device)
        if self.preserve_sign:
            eps = torch.where(inputs < 0, -eps, eps)
        # The selection mask is computed from detached values; the kept
        # elements themselves are taken from the original `inputs`.
        return inputs.where(inputs.detach().abs() > self.eps, eps)
class InvertibleBatchNorm(nn.Module):
    """Invertible batchnorm layer (inverse doesn't update running stats)

    Wraps an existing batch-norm module and reuses its parameters, buffers
    and hyper-parameters: any attribute that is not found on the wrapper
    (``eps``, ``momentum``, ``affine``, ``weight``, ``running_mean``, ...) is
    looked up on the wrapped module instead.

    Unlike a regular batch norm, ``forward`` also returns the statistics it
    normalized with, so that ``inverse`` can undo the transformation.

    Args:
        batch_norm: The batch-norm module to wrap. Its affine weight is
            clamped away from zero in place (threshold ``eps / 100``) so the
            affine step stays invertible.
    """

    def __init__(self, batch_norm):
        super().__init__()
        self.batch_norm = batch_norm
        self.non_zero = NonZero(self.eps / 100)

    def _nonzero_weight(self):
        """Clamp the affine weight away from zero in place and return it"""
        # Written through `.data` so the clamp is not recorded by autograd.
        self.weight.data.copy_(self.non_zero(self.weight.data))
        return self.weight

    def forward(self, inputs):
        """Perform the forward pass

        Args:
            inputs: Tensor with the channel axis at dimension 1.

        Returns:
            A tuple ``(out, mean, var)``: the normalized (and, if ``affine``,
            scaled and shifted) tensor plus the per-channel statistics that
            were used. When running statistics are used (eval mode with
            ``track_running_stats``), ``mean`` and ``var`` are the running
            buffers themselves, not copies.

        Raises:
            ZeroDivisionError: While updating running statistics, if the
                input holds exactly one value per channel.
        """
        # get/compute stats
        if self.training or not self.track_running_stats:
            # Reduce over every axis except the channel axis; statistics are
            # taken from detached inputs (biased variance).
            dim = tuple(set(range(inputs.ndim)) - {1})
            var, mean = torch.var_mean(inputs.detach(), dim, unbiased=False)
        else:
            var, mean = self.running_var, self.running_mean

        # compute output
        # `shape` broadcasts per-channel vectors against (N, C, ...) inputs.
        shape = (mean.numel(), ) + (1, ) * (inputs.ndim - 2)
        out = (inputs - mean.view(shape)) / (var.view(shape) + self.eps).sqrt()
        if self.affine:
            weight = self._nonzero_weight()
            out = weight.view(shape) * out + self.bias.view(shape)

        # update stats
        if self.training and self.track_running_stats:
            self.batch_norm.num_batches_tracked = self.num_batches_tracked + 1
            if self.momentum is None:
                # No momentum: cumulative moving average.
                factor = 1 / self.num_batches_tracked
            else:
                factor = self.momentum
            # Running variance stores the unbiased estimate: n / (n - 1) per
            # channel, expressed with total element counts.
            unbias = inputs.numel() / (inputs.numel() - var.numel())
            self.running_mean.mul_(1 - factor).add_(factor * mean)
            self.running_var.mul_(1 - factor).add_(factor * unbias * var)
        return out, mean, var

    def inverse(self, inputs, mean, var):
        """Perform the inverse pass

        Args:
            inputs: Output previously produced by ``forward``.
            mean: Per-channel mean returned by that ``forward`` call.
            var: Per-channel variance returned by that ``forward`` call.

        Returns:
            The reconstructed input of ``forward``. Running statistics are
            not touched.
        """
        shape = (mean.numel(), ) + (1, ) * (inputs.ndim - 2)
        if self.affine:
            # Use the same clamped weight as `forward`, so the division is
            # well defined even if `forward` has not run since the weight
            # was last modified.
            weight = self._nonzero_weight()
            inputs = (inputs - self.bias.view(shape)) / weight.view(shape)
        return inputs * (var.view(shape) + self.eps).sqrt() + mean.view(shape)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Without this guard a missing `batch_norm` would make the
            # fallback below re-enter __getattr__ forever.
            if name == 'batch_norm':
                raise
            return getattr(self.batch_norm, name)
@torch.no_grad()
def _test(total=1000):
    """Empirically check that ``inverse`` undoes ``forward`` in eval mode

    Each trial builds a fresh ``InvertibleBatchNorm`` around a
    ``BatchNorm2d(3, momentum=1)``, randomizes its affine parameters, runs
    three training batches to populate the running statistics, switches to
    eval mode and compares a random input with its reconstruction
    (``atol=1e-5``). The success percentage is printed.

    Args:
        total: Number of independent trials to run.

    Returns:
        The fraction of trials (between 0 and 1) whose reconstruction
        matched the input.

    Raises:
        ValueError: If ``total`` is not positive.
    """
    if total <= 0:
        raise ValueError(f'total must be positive, got {total}')
    count = 0
    for _ in range(total):
        norm = InvertibleBatchNorm(nn.BatchNorm2d(3, momentum=1))
        if norm.affine:
            # Large weight spread on purpose: exercises small and big scales.
            norm.weight.detach().normal_(0, 10)
            norm.bias.detach().normal_()
        shape = (2, norm.batch_norm.num_features, 5, 5)
        if norm.track_running_stats:
            for _warmup in range(3):
                norm(torch.randn(shape))
        norm.eval()
        inputs = torch.randn(shape)
        count += torch.allclose(inputs, norm.inverse(*norm(inputs)), atol=1e-5)
    rate = count / total
    print(f'correct {rate * 100:.2f}% of the time')
    return rate
# Run the round-trip self-check only when executed as a script.
if __name__ == "__main__":
    _test()
# End of gist (GitHub page footer removed).