Example of torch.fx, showing how to fuse conv-bn module pairs
""" Fuse conv-bn pattern in torch.Module, an example for torch.fx | |
see: https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html | |
""" | |
import copy | |
from typing import Tuple, Dict, Any | |
import torch | |
import torch.fx as fx | |
import torch.nn as nn | |
from ipdb import set_trace | |
# helper functions to fuse the conv and bn | |
# nothing special, just math operations | |
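# For reference, the arithmetic behind the fusion (standard batch-norm folding,
# nothing specific to this gist): in eval mode the bn computes
#     y = (x - running_mean) / sqrt(running_var + eps) * gamma + beta
# so with s = gamma / sqrt(running_var + eps), folding it into the preceding
# conv amounts to
#     W_fused = W_conv * s    (broadcast over the output-channel axis)
#     b_fused = (b_conv - running_mean) * s + beta
# which is what `fuse_conv_bn_weights` below computes.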
def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
        fused_conv.weight, fused_conv.bias,
        bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
    return fused_conv


def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)

# Module part: create an nn.Module with the conv-bn pattern inside.
# Notice that the conv-bn pairs can appear in quite flexible ways:
# - used in the nn.Module directly
# - with the bn wrapped in another nn.Module
# - nested inside a Sequential container
class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x


# create an instance of the module
model = M().eval()
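# (Optional) To see the conv-bn pattern in graph form before fusing, one can
# print the symbolically traced graph; `Graph.print_tabular()` would show a
# nicer table if the `tabulate` package is installed.
print(fx.symbolic_trace(model).graph)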
# Let's start!
def _parent_name(target: str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert isinstance(node.target, str)
    # this is for nested modules
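    # e.g. a conv inside `self.nested` has the qualified name 'nested.1', which
    # splits into parent 'nested' and attribute '1'; setattr on the parent
    # Sequential then swaps the fused conv into that slot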
    parent_name, name = _parent_name(node.target)
    print(f'Modules[{parent_name}].{name} <- {new_module.__class__.__name__}')
    setattr(modules[parent_name], name, new_module)

def fuse(model: nn.Module) -> nn.Module:
    """Fuse the conv and bn; this is where the magic happens."""
    # get the graph representation of the model
    model = copy.deepcopy(model)
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    # Each `GraphModule` has a `Graph` associated with it.
    # The `Graph` itself is represented as a list of `Node` objects.
    # To iterate over the `Graph`, we iterate over its `Node`s.
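    # (For orientation: besides `call_module`, a `Node.op` can also be
    # `placeholder`, `get_attr`, `call_function`, `call_method`, or `output`.)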
    for node in fx_model.graph.nodes:
        # only consider the nodes of `call_module` type
        if node.op != 'call_module':
            continue
        # For call sites, `Node.target` represents the module/function/method
        # that's being called.
        # Here, we check `Node.target` to see if it's a batch norm module,
        # and then check `Node.args[0].target` to see if the input `Node` is
        # a convolution.
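        # e.g. for `x = self.bn1(self.conv1(x))`, the bn node's `target` is
        # 'bn1' and `node.args[0].target` is 'conv1'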
        cur_module = modules[node.target]
        if isinstance(cur_module, nn.BatchNorm2d):
            prev_module = modules[node.args[0].target]
            if isinstance(prev_module, nn.Conv2d):
                # found a conv-bn pattern
                if len(node.args[0].users) > 1:
                    # the output of the conv is also used by other nodes
                    continue
                fused_conv = fuse_conv_bn_eval(prev_module, cur_module)
                replace_node_module(node.args[0], modules, fused_conv)
                # As we've folded the batch norm into the conv, we need to
                # replace all uses of the batch norm with the conv.
                node.replace_all_uses_with(node.args[0])
                # Now that all uses of the batch norm have been replaced, we can
                # safely remove the batch norm.
                fx_model.graph.erase_node(node)
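    # `graph.lint()` runs consistency checks (e.g. that nodes are still in
    # topological order) to make sure the edited graph is well-formed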
    fx_model.graph.lint()
    # After we've modified the graph, we need to recompile it
    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model

fused_model = fuse(model)
print(f'The `forward` code after fusion:\n{fused_model.code}')
# check that the fused model gives the same output
inp = torch.randn(5, 1, 1, 1)
if torch.allclose(fused_model(inp), model(inp)):
    print('Fused successfully')
else:
    print('Failed to fuse: the difference is too large')
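
# (Optional) A rough timing comparison on a bigger batch. The input shape and
# repetition count are arbitrary choices for illustration, and absolute numbers
# will vary with hardware; the point is only that the fused model produces the
# same result with fewer modules to execute.
import time

bench_inp = torch.randn(64, 1, 32, 32)
with torch.no_grad():
    for name, m in [('original', model), ('fused', fused_model)]:
        start = time.perf_counter()
        for _ in range(100):
            m(bench_inp)
        print(f'{name:>8s}: {time.perf_counter() - start:.4f}s for 100 runs')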