Skip to content

Instantly share code, notes, and snippets.

@kashif
Last active April 10, 2020 22:40
Show Gist options
  • Save kashif/ff44b17a6da18ec5128678d100c3818f to your computer and use it in GitHub Desktop.
Save kashif/ff44b17a6da18ec5128678d100c3818f to your computer and use it in GitHub Desktop.
EvoNorm-S0 in PyTorch from https://arxiv.org/pdf/2004.02967.pdf
import torch
import torch.nn as nn
class EvoNorm2d(nn.Module):
    """EvoNorm-S0 normalization-activation layer for 4-D (N, C, H, W) inputs.

    Implements the sample-based EvoNorm-S0 layer from Liu et al.,
    "Evolving Normalization-Activation Layers" (arXiv:2004.02967)::

        y = x * sigmoid(v * x) / group_std(x) * weight + bias

    When ``nonlinearity`` is False only the per-channel affine transform
    ``x * weight + bias`` is applied (no ``v`` parameter is created).

    Args:
        num_features: Number of channels ``C`` of the input.
        eps: Constant added to the group variance before the square root.
        nonlinearity: If True, apply the EvoNorm-S0 numerator and group
            normalization; otherwise apply only the affine transform.
        groups: Number of channel groups used by ``group_std``; ``C`` must be
            divisible by it (the reshape in ``group_std`` fails otherwise).
            Defaults to 32, the value used previously.
    """

    __constants__ = ['num_features', 'eps', 'nonlinearity', 'groups']

    def __init__(self, num_features, eps=1e-5, nonlinearity=True, groups=32):
        super(EvoNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.nonlinearity = nonlinearity
        self.groups = groups
        # Shaped (1, C, 1, 1) so they broadcast per channel over (N, C, H, W).
        # Allocated uninitialized; reset_parameters() fills them below.
        self.weight = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.empty(1, num_features, 1, 1))
        if self.nonlinearity:
            self.v = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and v to ones and bias to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.nonlinearity:
            nn.init.ones_(self.v)

    def group_std(self, x, groups=32):
        """Return the per-sample, per-group standard deviation of ``x``.

        ``x`` of shape (N, C, H, W) is split into ``groups`` groups of
        ``C // groups`` consecutive channels. The std is taken over each
        group's channels and spatial positions, then broadcast back to
        shape (N, C, H, W).

        As in the paper, the biased (population) variance is used and ``eps``
        is added inside the square root. A constant group, or one with a
        single element, therefore yields ``sqrt(eps)`` rather than 0 or NaN.
        """
        N, C, H, W = x.shape
        x = torch.reshape(x, (N, groups, C // groups, H, W))
        var = torch.var(x, dim=(2, 3, 4), unbiased=False, keepdim=True)
        std = torch.sqrt(var + self.eps).expand_as(x)
        return torch.reshape(std, (N, C, H, W))

    def forward(self, x):
        """Apply EvoNorm-S0 (or only the affine transform) to ``x``."""
        if self.nonlinearity:
            num = x * torch.sigmoid(self.v * x)
            return num / self.group_std(x, self.groups) * self.weight + self.bias
        return x * self.weight + self.bias
@pinouchon
Copy link

pinouchon commented Apr 9, 2020

Do you have the EvoNorm-B0 by any chance? It looks like this is EvoNorm-S0. I changed it to 1D and it seems to work fine, but I would still prefer the batch version.

@digantamisra98
Copy link

@pinouchon I have EvoNorm-B0 here: https://github.com/digantamisra98/EvoNorm. However, I still have one shape-mismatch error to solve in the running-variance calculation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment