@sandeepkumar-skb
Last active October 8, 2020 20:42
GroupNorm implementation in PyTorch
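For reference, GroupNorm splits the C channels of each sample into G groups and normalizes each group over its channels and spatial positions, then applies a per-channel affine transform (the weight and bias parameters in the code below). In formula form:

    \hat{x}_{n,g} = \frac{x_{n,g} - \mu_{n,g}}{\sqrt{\sigma_{n,g}^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta

where \mu_{n,g} and \sigma_{n,g}^2 are the mean and variance over group g of sample n, and \gamma, \beta correspond to weight and bias.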
import torch
import torch.nn as nn


class GroupNorm(nn.Module):
    def __init__(self, num_groups, num_features, eps=1e-5):
        super(GroupNorm, self).__init__()
        # Per-channel affine parameters, broadcastable over (N, C, H, W).
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.num_groups = num_groups
        self.eps = eps

    def forward(self, x):
        N, C, H, W = x.size()
        G = self.num_groups
        assert C % G == 0

        # Normalize over each group of C // G channels and all spatial positions.
        # Note: Tensor.var defaults to the unbiased estimator, whereas
        # nn.GroupNorm uses the biased (population) variance, so the two
        # outputs differ by a tiny amount.
        x = x.view(N, G, -1)
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True)
        x = (x - mean) / (var + self.eps).sqrt()

        # Restore the original shape and apply the per-channel affine transform.
        x = x.view(N, C, H, W)
        return x * self.weight + self.bias


# Compare against PyTorch's built-in nn.GroupNorm.
gn = GroupNorm(num_groups=32, num_features=128)
nn_gn = nn.GroupNorm(num_groups=32, num_channels=128)

x = torch.rand([10, 128, 56, 56])
x_gn = gn(x)
nn_x_gn = nn_gn(x)
@sandeepkumar-skb (Author) commented:
result: tensor(7.0453e-05, grad_fn=)
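The gist does not show the exact expression behind this number; one plausible check (an assumption, not necessarily the author's expression) is the maximum absolute difference between the two outputs. A small but non-zero gap is consistent with the unbiased-vs-biased variance note in the code above.

# Hypothetical comparison -- the exact expression used is not shown in the gist.
diff = (x_gn - nn_x_gn).abs().max()
print(diff)  # a small value, on the order of 1e-5 for this input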
