Created
January 24, 2022 18:13
-
-
Save jamesr66a/5d4e22e1008f76c8fc97d350b0dddceb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Scrape artifacts (trailing " | |") removed so the file is runnable.
import torch
import torch.fx
class Foo(torch.nn.Module):
    """Minimal module whose forward pass is a single elementwise ReLU."""

    def forward(self, x):
        """Return ``torch.relu(x)``; ``x`` is any tensor."""
        return torch.relu(x)
class Bar(torch.nn.Module):
    """Wrapper module that delegates its forward pass to a child ``Foo``."""

    def __init__(self):
        super().__init__()
        # Child submodule; the demo registers a forward hook on this
        # attribute (``b.f``) later in the script.
        self.f = Foo()

    def forward(self, x):
        """Pass ``x`` straight through the ``Foo`` child."""
        return self.f(x)
# Module instance exercised by the eager and traced runs below.
b = Bar()
@torch.fx.wrap
def print_here(output):
    """Print a marker line and return ``output`` unchanged.

    ``torch.fx.wrap`` registers this function as a leaf for symbolic
    tracing: the tracer emits a ``call_function`` node for it in the
    graph instead of tracing through its body, so the ``print`` still
    fires when the traced module is actually run.
    """
    print('here')
    return output
def here(self, input, output):
    """Forward hook: route the module's output through ``print_here``.

    Matches the ``register_forward_hook`` signature
    ``(module, input, output)``; returning a value replaces the
    module's output with it.
    """
    return print_here(output)
# Attach the hook to the Foo submodule so every forward call prints 'here'.
b.f.register_forward_hook(here)

# Eager run: the hook fires normally.
print('** run')
b(torch.randn(5, 3))
"""
** run
here
"""

# Symbolic tracing: no output is printed while tracing.
print('** trace')
traced = torch.fx.symbolic_trace(b)
# Note there is no output during tracing, as the call to `print_here` is
# directly emitted as a call in the graph, rather than tracing through it
"""
** trace
"""

# Running the traced module executes the emitted call, so 'here' prints again.
print('** run trace')
traced(torch.randn(5, 3))
"""
** run trace
here
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.