Skip to content

Instantly share code, notes, and snippets.

@chenyaofo
Created December 19, 2022 11:14
Show Gist options
  • Save chenyaofo/78615a25f4bef43344e538fde22e9cbc to your computer and use it in GitHub Desktop.
Question about TBR.
import torch
import torch.nn as nn
from torchvision.models import resnet50
import types
import copy
def tbr_bn_forward_impl(self: nn.BatchNorm2d, x: torch.Tensor) -> torch.Tensor:
    """Batch-renormalized forward pass for an ``nn.BatchNorm2d`` module.

    Meant to be bound onto a BatchNorm2d instance (e.g. with
    ``types.MethodType``). The input is normalized with the current batch
    statistics and then corrected with the batch-renormalization factors

        r = sigma_B / sigma,    d = (mu_B - mu) / sigma,

    where ``mu`` is ``running_mean`` and ``sigma = sqrt(running_var + eps)``
    (``running_var`` is a *variance*, as in ``nn.BatchNorm2d``). Since ``r``
    and ``d`` are detached, the forward value equals eval-mode BatchNorm with
    the pre-update running statistics, while gradients flow through the batch
    statistics.

    After normalization the running statistics are updated in place with an
    exponential moving average weighted by ``self.momentum``. This happens on
    every call, regardless of ``self.training``. The buffers keep their
    original ``(C,)`` shape.

    Args:
        self: the BatchNorm2d module supplying buffers, eps, momentum and
            (optional) affine parameters. ``momentum`` must not be None.
        x: input tensor; reductions are over dims (0, 2, 3), so it is assumed
            to be (N, C, H, W).

    Returns:
        Tensor with the same shape as ``x``.
    """
    stat_dims = (0, 2, 3)
    batch_mean = x.mean(dim=stat_dims, keepdim=True)
    # Unbiased variance (the original code's choice, and what PyTorch stores
    # in running_var during training).
    batch_var = x.var(dim=stat_dims, keepdim=True, unbiased=True)
    batch_std = torch.sqrt(batch_var + self.eps)

    if self.running_mean is None or self.running_var is None:
        # No running statistics (e.g. track_running_stats=False): seed them
        # from this batch, which makes r == 1 and d == 0 for this call.
        self.running_mean = batch_mean.detach().flatten().clone()
        self.running_var = batch_var.detach().flatten().clone()

    # Broadcastable views only; the registered buffers are not reshaped, so
    # the module's state_dict stays compatible.
    running_mean = self.running_mean.view(1, -1, 1, 1)
    running_std = torch.sqrt(self.running_var.view(1, -1, 1, 1) + self.eps)

    r = batch_std.detach() / running_std
    d = (batch_mean.detach() - running_mean) / running_std
    out = (x - batch_mean) / batch_std * r + d

    # EMA update of the running statistics (variance, not std, goes into
    # running_var). r and d above were computed from the pre-update values.
    with torch.no_grad():
        self.running_mean.add_(
            self.momentum * (batch_mean.detach().flatten() - self.running_mean)
        )
        self.running_var.add_(
            self.momentum * (batch_var.detach().flatten() - self.running_var)
        )

    if self.weight is not None:
        out = out * self.weight.view(1, -1, 1, 1)
    if self.bias is not None:
        out = out + self.bias.view(1, -1, 1, 1)
    return out
def normal_bn_forward_impl(self: nn.BatchNorm2d, x: torch.Tensor) -> torch.Tensor:
    """Eval-mode BatchNorm2d forward computed from the module's running stats.

    Computes ``(x - running_mean) / sqrt(running_var + eps)`` per channel and
    then applies ``weight``/``bias`` when the module has them. The running
    statistics are read but never modified.

    Args:
        self: the BatchNorm2d module supplying running_mean, running_var, eps
            and the (optional) affine parameters.
        x: input tensor; channel stats are broadcast as (1, C, 1, 1), so it
            is assumed to be (N, C, H, W).

    Returns:
        Tensor with the same shape as ``x``.
    """
    mean = self.running_mean.view(1, -1, 1, 1)
    # eps goes inside the square root (as in nn.BatchNorm2d), not added to
    # the normalized output.
    std = torch.sqrt(self.running_var.view(1, -1, 1, 1) + self.eps)
    out = (x - mean) / std
    if self.weight is not None:
        out = out * self.weight.view(1, -1, 1, 1)
    if self.bias is not None:
        out = out + self.bias.view(1, -1, 1, 1)
    return out
# Compare the TBR forward against the stock eval-mode forward on the first
# BatchNorm layer of a pretrained ResNet-50 (weights are fetched by
# torchvision on first use).
model = resnet50(pretrained=True)
bn = model.bn1
print(bn.weight)

# Independent copy that runs the TBR forward; `bn` keeps the stock forward.
bn1 = copy.deepcopy(bn)
bn1.forward = types.MethodType(tbr_bn_forward_impl, bn1)

x = torch.rand(16, 64, 32, 32)

bn.eval()
bn1.eval()
with torch.no_grad():
    y, y1 = bn(x), bn1(x)

# Worst-case and average absolute discrepancy between the two forwards.
abs_diff = (y - y1).abs()
print(abs_diff.max())
print(abs_diff.mean())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.