@fg91
Last active August 11, 2020 18:54
class SaveFeatures():
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.features = torch.tensor(output,requires_grad=True).cuda()
    def close(self):
        self.hook.remove()
@rcalfredson

Thanks for your article about neural net visualization! For more recent versions of PyTorch, I found it is necessary to change line 5 to self.features = output for the back-propagation to succeed.
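
To make the suggested change concrete, here is a minimal sketch of the hook with `self.features = output`, followed by a check that gradients reach the input image. The toy model, the tensor shapes, and the variable names `model`, `img`, and `activations` are assumptions made for illustration and are not part of the gist; the sketch also runs on CPU, so the `.cuda()` call is left out. As far as I can tell, the reason for the change is that `torch.tensor(output, requires_grad=True)` copies the data into a new tensor that is cut off from the autograd graph, whereas storing `output` directly keeps the connection to the input.

```python
import torch
import torch.nn as nn


class SaveFeatures:
    """Store a module's forward output so a loss can be built from it."""

    def __init__(self, module):
        self.features = None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Keep the output itself: it is still attached to the autograd graph.
        self.features = output

    def close(self):
        self.hook.remove()


torch.manual_seed(0)

# Hypothetical toy model, used here only to exercise the hook.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.ReLU()).eval()
img = torch.rand(1, 3, 8, 8, requires_grad=True)

activations = SaveFeatures(model[0])
model(img)  # the forward pass triggers hook_fn

# Maximise the mean activation of one channel by minimising its negative.
loss = -activations.features[0, 1].mean()
loss.backward()
activations.close()

# True when back-propagation made it from the hooked layer to the image.
grad_reaches_input = img.grad is not None
```

If the stored features were a detached copy instead, `loss.backward()` would have no path back to `img`, which matches the failure described in the comment above.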
