""" Fuse conv-bn pattern in torch.Module, an example for torch.fx
see: https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html
"""
import copy
from typing import Tuple, Dict, Any
import torch
import torch.fx as fx
import torch.nn as nn
# from ipdb import set_trace  # unused debugging helper; uncomment if you want to step through
# helper functions to fuse the conv and bn
# nothing special, just math operations
def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert not (conv.training or bn.training), "Fusion is only valid in eval mode!"
    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps,
                             bn.weight, bn.bias)
    return fused_conv
def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
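
# Quick sanity check of the helpers above (an illustrative addition, not part of the
# original gist): fuse a standalone conv/bn pair and confirm that the fused conv
# reproduces bn(conv(x)) in eval mode. The shapes here are arbitrary.
_conv, _bn = nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)
_bn(torch.randn(4, 8, 8, 8))  # one training-mode pass to get non-trivial running stats
_conv, _bn = _conv.eval(), _bn.eval()
_x = torch.randn(2, 3, 16, 16)
assert torch.allclose(fuse_conv_bn_eval(_conv, _bn)(_x), _bn(_conv(_x)), atol=1e-5), \
    'standalone conv-bn fusion should match the unfused computation'
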
# Module part: create an nn.Module with conv-bn patterns inside.
# Notice that the conv-bn pairs can appear in fairly flexible ways:
# - used in the nn.Module directly
# - with the bn wrapped inside another module
# - nested inside a Sequential container
class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x
# create an instance of the module
model = M().eval()
# Let's start!
def _parent_name(target: str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name
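# For example (not in the original gist): _parent_name('nested.1') == ('nested', '1'),
# while _parent_name('conv1') == ('', 'conv1'), so top-level modules get an empty parent path.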
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert isinstance(node.target, str)
    # this is for nested modules
    parent_name, name = _parent_name(node.target)
    print(f'Modules[{parent_name}].{name} <- {new_module.__class__.__name__}')
    setattr(modules[parent_name], name, new_module)
def fuse(model: nn.Module) -> nn.Module:
    """ Fuse the conv and bn modules; this is where the magic happens.
    """
    # get the graph representation of the model
    model = copy.deepcopy(model)
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    # Each `GraphModule` has a `Graph` associated with it.
    # The `Graph` itself is represented as a list of `Node` objects.
    # To traverse the `Graph`, we iterate over its `Node`s.
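    # (Aside, not part of the original gist: `fx_model.graph.print_tabular()` prints
    # one row per node with its opcode, name, target and args, which makes the
    # placeholder / call_module / output structure described below easy to inspect.)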
    for node in fx_model.graph.nodes:
        # only consider nodes of the `call_module` type
        if node.op != 'call_module':
            continue
        # For call sites, `Node.target` represents the module/function/method
        # that's being called.
        # Here, we check `Node.target` to see if it's a batch norm module,
        # and then check `Node.args[0].target` to see if the input `Node` is
        # a convolution.
        cur_module = modules[node.target]
        if isinstance(cur_module, nn.BatchNorm2d):
            prev_module = modules[node.args[0].target]
            if isinstance(prev_module, nn.Conv2d):
                # found a conv-bn pattern
                if len(node.args[0].users) > 1:
                    # the output of the conv is used by other nodes, so it cannot be folded away
                    continue
                fused_conv = fuse_conv_bn_eval(prev_module, cur_module)
                replace_node_module(node.args[0], modules, fused_conv)
                # As we've folded the batch norm into the conv, we need to
                # replace all uses of the batch norm with the conv.
                node.replace_all_uses_with(node.args[0])
                # Now that all uses of the batch norm have been replaced, we can
                # safely remove the batch norm.
                fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    # After we've modified our graph, we need to recompile our graph in order
    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model
fused_model = fuse(model)
print(f'The `forward` code after fusion:\n{fused_model.code}')
# check the output
inp = torch.randn(5, 1, 1, 1)
if torch.allclose(fused_model(inp), model(inp)):
    print('Fused successfully')
else:
    print('Failed to fuse: the difference between the outputs is too large')
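
# (Illustrative addition, not part of the original gist): after fusion the graph
# should only call the conv modules; the batch-norm `call_module` nodes were erased.
print('call_module targets after fusion:',
      [n.target for n in fused_model.graph.nodes if n.op == 'call_module'])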