Skip to content

Instantly share code, notes, and snippets.

@chenyaofo
Last active December 20, 2022 07:11
Show Gist options
  • Save chenyaofo/a5c32c344e8a29e37ecf919094519c37 to your computer and use it in GitHub Desktop.
Save chenyaofo/a5c32c344e8a29e37ecf919094519c37 to your computer and use it in GitHub Desktop.
Solution for TBR.
import torch
import torch.nn as nn
from torchvision.models import resnet50
import types
import copy
def tbr_bn_forward_impl(self: nn.BatchNorm2d, x: torch.Tensor) -> torch.Tensor:
    """Batch-renormalization (TBR) forward pass, meant to be bound onto an ``nn.BatchNorm2d``.

    The input is normalized with the current batch statistics and then corrected
    with ``r = sigma_B / sigma`` and ``d = (mu_B - mu) / sigma``, where
    ``sigma_B = sqrt(batch_var + eps)`` and ``sigma = sqrt(running_var + eps)``.
    Both factors are detached, so they act as constants during backward. After
    that, the running statistics are updated in place with an exponential moving
    average weighted by ``self.momentum``. This happens whether or not the module
    is in training mode, because ``self.training`` is never consulted.

    Typical use: ``bn.forward = types.MethodType(tbr_bn_forward_impl, bn)``.

    Args:
        self: Batch-norm module providing ``eps``, ``momentum``, the running-stat
            buffers and, when ``affine=True``, ``weight``/``bias``.
            ``self.momentum`` must be a number; ``momentum=None`` (cumulative
            average) is not supported here.
        x: 4-D input; statistics are reduced over dims 0, 2 and 3, so dim 1 is
            the channel dimension.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Per-channel batch statistics kept as (1, C, 1, 1) for broadcasting.
    # torch.var_mean uses the unbiased estimator by default.
    batch_var, batch_mean = torch.var_mean(x, dim=(0, 2, 3), keepdim=True)
    batch_std = torch.sqrt(batch_var + self.eps)

    if self.running_mean is None:
        # No tracked statistics (e.g. track_running_stats=False): seed them from
        # this batch, stored flat (C,) like nn.BatchNorm2d's own buffers.
        self.running_mean = batch_mean.detach().reshape(-1).clone()
        self.running_var = batch_var.detach().reshape(-1).clone()

    # Broadcastable *local* views. The registered buffers keep their original
    # shape, so state_dict()/load_state_dict() and the stock forward still work.
    running_mean = self.running_mean.view(1, -1, 1, 1)
    running_std = torch.sqrt(self.running_var.view(1, -1, 1, 1) + self.eps)

    # Renormalization factors; detached so no gradient flows through them.
    r = batch_std.detach() / running_std
    d = (batch_mean.detach() - running_mean) / running_std
    # The forward value algebraically equals (x - running_mean) / running_std;
    # only the gradient path (through the batch statistics) differs.
    x = ((x - batch_mean) / batch_std) * r + d

    # In-place EMA update of the buffers. view_as keeps the buffer's own shape.
    self.running_mean.add_(
        self.momentum * (batch_mean.detach().view_as(self.running_mean) - self.running_mean)
    )
    self.running_var.add_(
        self.momentum * (batch_var.detach().view_as(self.running_var) - self.running_var)
    )

    # affine=False modules have weight/bias set to None.
    if self.weight is not None:
        x = x * self.weight.view(1, -1, 1, 1)
    if self.bias is not None:
        x = x + self.bias.view(1, -1, 1, 1)
    return x
def normal_bn_forward_impl(self: nn.BatchNorm2d, x: torch.Tensor) -> torch.Tensor:
    """Inference-style batch-norm forward that uses only the running statistics.

    Computes ``weight * (x - running_mean) / sqrt(running_var + eps) + bias``
    per channel, the same formula ``nn.BatchNorm2d`` applies in eval mode.
    The running statistics are read but never modified.

    Args:
        self: Batch-norm module providing ``running_mean``, ``running_var``,
            ``eps`` and, when ``affine=True``, ``weight``/``bias``.
        x: Input whose dim 1 is the channel dimension; the statistics are
            reshaped to ``(1, C, 1, 1)`` for broadcasting.

    Returns:
        Tensor with the same shape as ``x``.
    """
    running_mean = self.running_mean.view(1, -1, 1, 1)
    # eps belongs inside the square root, where it regularizes the variance.
    # It used to be added to the normalized output instead.
    running_std = torch.sqrt(self.running_var.view(1, -1, 1, 1) + self.eps)
    x = (x - running_mean) / running_std
    # affine=False modules have weight/bias set to None.
    if self.weight is not None:
        x = x * self.weight.view(1, -1, 1, 1)
    if self.bias is not None:
        x = x + self.bias.view(1, -1, 1, 1)
    return x
# Demo: on pretrained ResNet-50's first BN layer, compare the stock eval-mode
# forward with the TBR forward bound onto an independent copy of the layer.
torch.manual_seed(0)  # make the random probe input reproducible
# weights="IMAGENET1K_V1" is what the deprecated pretrained=True maps to
# (torchvision >= 0.13 multi-weight API).
model = resnet50(weights="IMAGENET1K_V1")
bn = model.bn1
print(bn.weight)
# Deep copy, so the TBR forward's in-place running-stat updates cannot touch `bn`.
bn1 = copy.deepcopy(bn)
x = torch.rand(16, 64, 32, 32)

bn.eval()
with torch.no_grad():
    y = bn(x)

# Module.__call__ dispatches to the instance attribute, i.e. the TBR forward.
bn1.forward = types.MethodType(tbr_bn_forward_impl, bn1)
# Note: eval() does not change the patched forward, which never reads self.training.
bn1.eval()
with torch.no_grad():
    y1 = bn1(x)

# The TBR forward value algebraically equals normalizing with the running
# statistics, so both numbers are expected to be at floating-point round-off level.
print(torch.max(torch.abs(y - y1)))
print(torch.mean(torch.abs(y - y1)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment