Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Last active March 10, 2022 23:55
Show Gist options
  • Save jamesr66a/ea64a220b76a60910bde8c7512b1f881 to your computer and use it in GitHub Desktop.
Save jamesr66a/ea64a220b76a60910bde8c7512b1f881 to your computer and use it in GitHub Desktop.
# a.py
import torch
class Foo(torch.nn.Module):
    """Add ``len(x)`` (the size of the tensor's first dimension) to ``x``."""

    def forward(self, x):
        # len() on a tensor is the extent of dim 0; broadcast-add it back in.
        first_dim = len(x)
        return x + first_dim
# b.py
import torch
class Bar(torch.nn.Module):
    """Add the tensor's rank (``len(x.shape)``) to every element of ``x``."""

    def forward(self, x):
        # Number of dimensions of the input tensor.
        rank = len(x.shape)
        return x + rank
# base.py
import torch
from a import Foo
from b import Bar
class MyModule(torch.nn.Module):
    """Pipeline of two submodules: the input flows through ``a`` then ``b``."""

    def __init__(self):
        super().__init__()
        # Submodule attribute names (`a`, `b`) are part of the module's
        # state-dict / named_modules interface — keep them stable.
        self.a = Foo()
        self.b = Bar()

    def forward(self, x):
        intermediate = self.a(x)
        return self.b(intermediate)
# Run the module eagerly once to show it works outside of FX tracing.
mm = MyModule()
x = torch.randn(5, 3)
mm(x)

import torch.fx

# Tracing directly fails, because `len` is a Python builtin FX refuses to
# record by default:
# traced = torch.fx.symbolic_trace(mm)
"""
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

# Workaround: `torch.fx.wrap('len')` must run at the module scope of each
# `forward` that calls `len`. Compile a tiny snippet once, then exec it into
# the globals dict of every submodule's forward function, which is exactly
# that module scope.
code = """
import torch.fx
torch.fx.wrap('len')
"""
compiled_code = compile(code, '<string>', 'exec')
for submodule_name, submodule in mm.named_modules():
    if not hasattr(submodule, 'forward'):
        continue
    fwd = submodule.forward
    exec(compiled_code, fwd.__globals__)

# With `len` registered as wrapped, tracing now succeeds.
traced = torch.fx.symbolic_trace(mm)
print(traced.code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment