@danesherbs
Created March 14, 2022 08:49
A PyTorch forward hook that stores the module, input, and output from the most recent forward pass.
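class StatefulHook:
    """Forward hook that stores the module, input, and output of the last forward pass."""

    def __init__(self):
        self.module = None
        self.input = None
        self.output = None

    def __call__(self, module, input, output):
        # PyTorch calls this right after `module.forward` has run
        self.module = module
        self.input = input  # tuple of the positional arguments passed to forward
        self.output = output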
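
net = ...  # some neural net with a submodule named `my_layer`
hook = StatefulHook()
net.my_layer.register_forward_hook(hook)

x = ...  # an input batch for `net`
net(x)  # the hook only fires during a forward pass, so run one before reading it

print(hook.input)   # tuple of positional inputs to `my_layer`
print(hook.output)  # output of `my_layer`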
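
A minimal, runnable sketch of the same idea follows. It assumes a toy `nn.Sequential` model whose first layer is named `my_layer`; the model, its layer sizes, and the input shape are made up purely for illustration. It also keeps the handle returned by `register_forward_hook`, so the hook can be detached with `handle.remove()` once it is no longer needed.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class StatefulHook:
    """Forward hook that remembers the last module, input and output it saw."""

    def __init__(self):
        self.module = None
        self.input = None
        self.output = None

    def __call__(self, module, input, output):
        # PyTorch calls this right after `module.forward` returns.
        self.module = module
        self.input = input    # tuple of the positional arguments to forward
        self.output = output  # whatever forward returned


# Hypothetical toy network: the name `my_layer` is chosen only so that
# `net.my_layer` exists, mirroring the attribute used in the gist.
net = nn.Sequential(OrderedDict([
    ("my_layer", nn.Linear(4, 3)),
    ("relu", nn.ReLU()),
    ("head", nn.Linear(3, 1)),
]))

hook = StatefulHook()
handle = net.my_layer.register_forward_hook(hook)

x = torch.randn(2, 4)  # a batch of 2 examples with 4 features
y = net(x)             # the hook fires during this forward pass

print(hook.module)          # Linear(in_features=4, out_features=3, bias=True)
print(hook.input[0].shape)  # torch.Size([2, 4])
print(hook.output.shape)    # torch.Size([2, 3])

handle.remove()  # detach the hook; later forward passes no longer update it
```

Because the hook stores references rather than copies, `hook.output` still holds onto its autograd graph. If only the values are needed, storing `output.detach()` instead is one option.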