Last active
January 10, 2021 21:17
-
-
Save sandeepkumar-skb/186f5e5c1549fd88cbd606ea2da44b6b to your computer and use it in GitHub Desktop.
Folding batch normalization (BN) into the preceding convolution
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import copy | |
import torchvision.models as models | |
class BN_Folder():
    """Fold ``nn.BatchNorm2d`` layers into the ``nn.Conv2d`` registered just before them.

    In eval mode a batch-norm layer is a per-channel affine transform, so it
    can be merged into the weights and bias of the convolution that feeds it.

    NOTE: "just before" means adjacent in the parent's ``_modules``
    registration order. That normally matches execution order, but nothing
    here verifies it against the parent's ``forward()``.
    """

    def fold(self, model):
        """Return an eval-mode deep copy of ``model`` with conv/BN pairs folded.

        The input model is left untouched. Every folded BatchNorm2d is replaced
        by ``nn.Identity()`` instead of being removed, so a ``forward()`` that
        still refers to the BN attribute by name keeps working.
        """
        # Copy once at the top; the recursive walk then edits the copy in place.
        mymodel = copy.deepcopy(model)
        mymodel.eval()
        self._fold_in_place(mymodel)
        return mymodel

    def _fold_in_place(self, module):
        """Recursively fold conv/BN pairs among the children of ``module``."""
        prev_name, prev_module = None, None
        for name, child in list(module._modules.items()):
            if child is None:
                # A child slot may be registered as None; nothing to fold.
                pass
            elif child._modules:
                # Container: descend into it.
                self._fold_in_place(child)
            elif self._is_foldable(prev_module, child):
                module._modules[prev_name] = self.fold_bn(prev_module, child)
                # Keep the slot so attribute/index lookups remain valid.
                module._modules[name] = nn.Identity()
            prev_name, prev_module = name, module._modules[name]

    def _is_foldable(self, conv, bn):
        """Return True when ``bn`` can be merged into ``conv``."""
        return (
            isinstance(bn, nn.BatchNorm2d)
            and isinstance(conv, nn.Conv2d)
            # Without running statistics there is nothing fixed to fold.
            and bn.running_mean is not None
            and bn.running_var is not None
            # Guards against a coincidental adjacency of unrelated layers.
            and conv.out_channels == bn.num_features
        )

    def fold_bn(self, conv, bn):
        """Return a copy of ``conv`` whose weight/bias absorb ``bn``.

        Neither ``conv`` nor ``bn`` is modified.
        """
        folded_conv = copy.deepcopy(conv)
        folded_conv.weight, folded_conv.bias = self.fold_bn_util(
            conv.weight,
            conv.bias,
            bn.weight,
            bn.bias,
            bn.running_var,
            bn.running_mean,
            bn.eps,
        )
        return folded_conv

    def fold_bn_util(self, conv_w, conv_b, bn_w, bn_b, bn_rv, bn_rm, bn_eps):
        """Compute the folded convolution parameters.

        With ``scale = bn_w / sqrt(bn_rv + bn_eps)``:

            folded_w = conv_w * scale        (broadcast over output channels)
            folded_b = (conv_b - bn_rm) * scale + bn_b

        ``conv_b`` may be None (conv without bias); ``bn_w`` / ``bn_b`` may be
        None (BN without affine parameters). Missing terms are treated as
        0, 1 and 0 respectively.

        Returns:
            A ``(weight, bias)`` tuple of ``torch.nn.Parameter``.
        """
        # The folded values are constants, so no autograd graph is needed.
        with torch.no_grad():
            # `is None`, not truthiness: `not tensor` raises for tensors
            # holding more than one element.
            if conv_b is None:
                conv_b = torch.zeros_like(bn_rm)
            if bn_w is None:
                bn_w = torch.ones_like(bn_rm)
            if bn_b is None:
                bn_b = torch.zeros_like(bn_rm)
            scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
            # view(-1, 1, 1, 1) lines the per-channel scale up with dim 0 of
            # the 4-D convolution weight.
            folded_w = conv_w * scale.view(-1, 1, 1, 1)
            folded_b = (conv_b - bn_rm) * scale + bn_b
        return torch.nn.Parameter(folded_w), torch.nn.Parameter(folded_b)
if __name__ == "__main__":
    # Demo: fold the conv/BN pairs of a pretrained torchvision ResNet-18 and
    # print the resulting module tree.
    # NOTE(review): newer torchvision releases appear to prefer a `weights=`
    # argument over `pretrained=True` - confirm against the installed version.
    rn18 = models.resnet18(pretrained=True)
    bn_folder = BN_Folder()
    new_mod = bn_folder.fold(rn18)
    print(new_mod)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import copy | |
import time | |
class Model(nn.Module):
    """Minimal conv -> batch-norm network used to benchmark BN folding."""

    def __init__(self):
        super().__init__()
        # 32 input channels, 64 output channels, 3x3 kernel.
        self.conv = torch.nn.Conv2d(32, 64, (3, 3))
        # One set of normalization statistics per conv output channel.
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, inp):
        """Apply the convolution and then batch normalization to ``inp``."""
        x = self.conv(inp)
        out = self.bn(x)
        return out
if __name__ == "__main__":
    # Benchmark: compare a conv + BN forward pass against a single conv whose
    # parameters have the BN folded in. Requires a CUDA device.
    model = Model()
    model.eval().cuda()

    # Inference only: keep autograd out of both the check and the timings.
    with torch.no_grad():
        x = torch.randn((64, 32, 56, 56), device='cuda')
        z = model(x)  # reference output (also serves as warm-up)

        conv = model.conv
        bn = model.bn

        conv_w = conv.weight
        conv_b = conv.bias
        bn_w = bn.weight
        bn_b = bn.bias
        bn_rv = bn.running_var
        bn_rm = bn.running_mean
        bn_eps = bn.eps

        # Fold BN into the conv: scale = gamma / sqrt(running_var + eps).
        bn_rv = torch.rsqrt(bn_rv + bn_eps)
        folded_w = conv_w * (bn_w * bn_rv).view(-1, 1, 1, 1)
        folded_b = (conv_b - bn_rm) * bn_w * bn_rv + bn_b

        folded_conv = copy.deepcopy(conv)
        folded_conv.weight = torch.nn.Parameter(folded_w)
        folded_conv.bias = torch.nn.Parameter(folded_b)

        y = folded_conv(x)
        print(torch.sum(y-z))
        # The signed sum above lets positive and negative errors cancel, so
        # also report the worst-case element-wise error.
        print("Max abs difference: {:.3e}".format((y - z).abs().max().item()))

        num_iter = 10000

        # CUDA kernels launch asynchronously: synchronize before starting and
        # before stopping the clock so the interval covers the actual work.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_iter):
            z = folded_conv(x)
        torch.cuda.synchronize()
        print("With conv-bn folding: {:.2f}ms".format((time.perf_counter() - start)*1000/num_iter))

        start = time.perf_counter()
        for _ in range(num_iter):
            z = model(x)
        torch.cuda.synchronize()
        print("Without conv-bn folding: {:.2f}ms".format((time.perf_counter() - start)*1000/num_iter))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The timings above were averaged over 10,000 iterations.