Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created November 11, 2021 19:13
Show Gist options
  • Save jamesr66a/d36bd887459e9089910120bb9a489ff5 to your computer and use it in GitHub Desktop.
Save jamesr66a/d36bd887459e9089910120bb9a489ff5 to your computer and use it in GitHub Desktop.
import torch
import torch.fx
import operator
from torch.fx.node import map_arg
from torch.fx.passes.shape_prop import ShapeProp
def binary_mapping(op):
    """Adapt *op* into a decomposition rule taking exactly two operands.

    The returned callable forwards its two arguments, in order, to ``op`` and
    returns whatever ``op`` returns.
    """
    return lambda a, b: op(a, b)
# (traced call target, replacement callable) pairs for two-operand ops.
binary_decompositions = [
    (operator.add, torch.add),
]

# Table consulted by decompose(): maps a call_function target found in the
# traced graph to a rule that rebuilds the op from proxies of its inputs.
# A comprehension keeps the loop variables out of the module namespace.
decomposition_rules = {
    old: binary_mapping(new) for old, new in binary_decompositions
}
def decompose(
    model: torch.nn.Module,
    sample_inputs,
    *,
    max_iterations: int = 5,
    rules=None,
) -> torch.nn.Module:
    """Rewrite ``call_function`` nodes of *model* using a decomposition table.

    The model is symbolically traced. Every ``call_function`` node whose
    target is a key of *rules* is replaced by whatever the mapped rule emits
    when called with ``torch.fx.Proxy`` objects for the node's inputs. All
    other nodes are copied unchanged.

    A rule may itself emit ops that have rules, so the pass is repeated. It
    stops when a pass rewrites nothing (a fixed point) or after
    ``max_iterations`` passes.

    Args:
        model: Module to decompose; must be traceable by
            ``torch.fx.symbolic_trace``.
        sample_inputs: Positional example inputs (unpacked with ``*``) that
            ``ShapeProp`` runs through each traced graph.
        max_iterations: Upper bound on rewrite passes (keyword-only).
        rules: Mapping from original call target to a callable that builds the
            replacement from proxies (keyword-only). Defaults to the
            module-level ``decomposition_rules`` table.

    Returns:
        A ``torch.fx.GraphModule`` containing the rewritten graph.

    Raises:
        ValueError: If ``max_iterations`` is less than 1.
    """
    if max_iterations < 1:
        raise ValueError(f"max_iterations must be >= 1, got {max_iterations}")
    if rules is None:
        rules = decomposition_rules

    for _ in range(max_iterations):
        traced = torch.fx.symbolic_trace(model)
        # Runs the traced graph on sample_inputs and records per-node tensor
        # metadata in node.meta (kept from the original pass; rules could
        # consult it via copied nodes).
        ShapeProp(traced).propagate(*sample_inputs)

        new_graph = torch.fx.Graph()
        # Lets Proxy operations append nodes directly into new_graph.
        tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
        env = {}  # old-graph Node -> equivalent Node in new_graph
        rewrote_any = False

        for node in traced.graph.nodes:
            if node.op == 'call_function' and node.target in rules:
                # Wrap the already-emitted inputs as proxies so the rule's
                # ordinary torch calls are recorded into new_graph.
                proxy_args = map_arg(node.args, lambda n: torch.fx.Proxy(env[n], tracer))
                proxy_kwargs = map_arg(node.kwargs, lambda n: torch.fx.Proxy(env[n], tracer))
                env[node] = rules[node.target](*proxy_args, **proxy_kwargs).node
                rewrote_any = True
            else:
                env[node] = new_graph.node_copy(node, lambda n: env[n])

        model = torch.fx.GraphModule(traced, new_graph)
        if not rewrote_any:
            # Fixed point reached: another pass would only re-trace the same graph.
            break

    return model
class SimpleAddModule(torch.nn.Module):
    """Minimal module whose forward is a single ``x + y``; used below to exercise decompose()."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return ``x + y``."""
        return x + y
# Demo: decompose a trivial x + y module and print the rewritten module.
model = SimpleAddModule()
x, y = torch.randn((4, 4)), torch.randn((4, 4))
traced_model = decompose(model, (x, y))
print(traced_model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment