Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created January 24, 2022 18:13
Show Gist options
  • Save jamesr66a/5d4e22e1008f76c8fc97d350b0dddceb to your computer and use it in GitHub Desktop.
import torch
import torch.fx
class Foo(torch.nn.Module):
    """Innermost submodule: applies ReLU element-wise to its input.

    It has no parameters or child modules; ``forward`` is a single call to
    the functional ``torch.relu``.
    """

    def forward(self, x):
        # Bind the result to a name before returning; the computation is
        # exactly one torch.relu call on the input.
        activated = torch.relu(x)
        return activated
class Bar(torch.nn.Module):
    """Wrapper module whose forward simply delegates to a ``Foo`` child.

    The child is stored under the attribute name ``f``, so code holding a
    ``Bar`` instance can reach the submodule as ``instance.f``.
    """

    def __init__(self):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule.
        self.f = Foo()

    def forward(self, x):
        # Invoke the child through its __call__ (not .forward) so any hooks
        # registered on it run as well.
        child = self.f
        return child(x)
# Module instance shared by the eager run and the FX trace further down.
b = Bar()
@torch.fx.wrap
def print_here(output):
    """Print the marker line ``here`` and return ``output`` unchanged.

    Decorated with ``torch.fx.wrap`` at module scope. Symbolic tracing therefore
    records a call to this function as a single graph node instead of
    tracing into its body. The ``print`` happens when the traced module
    runs, not while tracing.
    """
    marker = 'here'
    print(marker)
    return output
def here(self, input, output):
    """Forward hook that routes the hooked module's output through ``print_here``.

    This is a plain function, not a method. Passed to
    ``register_forward_hook``, it is called as ``hook(module, args, output)``.
    So ``self`` is the module the hook is attached to, and ``input`` holds
    that module's positional inputs. The parameter names are kept as-is for
    compatibility; ``input`` shadows the builtin only inside this function.

    Returns whatever ``print_here`` returns, which is ``output`` itself when
    run eagerly. A non-None return from a forward hook replaces the
    module's output.
    """
    # During symbolic tracing `output` is an FX proxy. Because print_here is
    # fx.wrap-ed, this call becomes a graph node instead of printing now.
    passed_through = print_here(output)
    return passed_through
# Attach the hook to the inner Foo submodule. The returned RemovableHandle is
# not kept because the hook stays registered for the whole demo.
b.f.register_forward_hook(here)

# 1) Eager run: Foo executes and the hook fires immediately.
print('** run')
b(torch.randn(5, 3))
# Expected output:
#   ** run
#   here

# 2) Symbolic trace: the hook still runs while tracing, but print_here is
# fx.wrap-ed, so its call is recorded in the graph rather than executed.
print('** trace')
traced = torch.fx.symbolic_trace(b)
# Note there is no output during tracing, as the call to `print_here` is
# directly emitted as a call in the graph, rather than tracing through it.
# Expected output:
#   ** trace

# 3) Running the traced GraphModule executes the recorded print_here node.
print('** run trace')
traced(torch.randn(5, 3))
# Expected output:
#   ** run trace
#   here
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment