Skip to content

Instantly share code, notes, and snippets.

@tvogels
Last active August 26, 2021 16:39
Show Gist options
  • Save tvogels/8dd148d96e8569bdd3bc0b257e7f54df to your computer and use it in GitHub Desktop.
Save tvogels/8dd148d96e8569bdd3bc0b257e7f54df to your computer and use it in GitHub Desktop.
capture module inputs in PyTorch
import torch
from contextlib import contextmanager
# Module-level capture state shared by capture_inputs() and CaptureModule.
# _capture_inputs_dict: the dict CaptureModule.forward writes into while a
#   capture is open (wrapper name -> list of recorded positional inputs).
# _capture_inputs_active: True only while a capture_inputs() context is open.
_capture_inputs_dict = {}
_capture_inputs_active = False


@contextmanager
def capture_inputs():
    """Record the inputs of every ``CaptureModule`` called inside the block.

    Usage::

        with capture_inputs() as inputs:
            model(x)
        inputs["encoder"]  # what CaptureModule(..., "encoder") received

    Yields:
        A fresh dict that ``CaptureModule`` wrappers fill while the context
        is open, keyed by each wrapper's ``name``.

    Raises:
        RuntimeError: If another ``capture_inputs()`` context is already
            active; nesting is not supported.
    """
    global _capture_inputs_dict
    global _capture_inputs_active
    # Explicit raise rather than ``assert`` so the guard survives ``python -O``.
    # Without it, a nested context would clear the active flag on exit and
    # silently stop the outer capture.
    if _capture_inputs_active:
        raise RuntimeError("only one at a time allowed")
    prev_dict = _capture_inputs_dict
    _capture_inputs_active = True
    _capture_inputs_dict = {}
    try:
        yield _capture_inputs_dict
    finally:
        # Runs even if the with-body raises, so the flag can never stay stuck
        # on; the yielded dict stops being the live write target.
        _capture_inputs_active = False
        _capture_inputs_dict = prev_dict
class CaptureModule(torch.nn.Module):
    """Wrap ``module`` so that ``capture_inputs()`` can record its inputs.

    Args:
        module: The module to wrap; every call is forwarded to it unchanged.
        name: Key under which this wrapper's inputs are stored in the dict
            yielded by ``capture_inputs()``.
    """

    def __init__(self, module, name):
        super().__init__()
        self.module = module
        self.name = name

    def forward(self, *args, **kwargs):
        """Optionally record the positional inputs, then call the wrapped module.

        While a ``capture_inputs()`` context is active, stores a list of the
        positional arguments under ``self.name``. Tensors are detached from
        the autograd graph, and any other value is stored as-is. A later call
        within the same context overwrites an earlier one. Keyword arguments
        are not recorded.

        Returns:
            Whatever ``self.module(*args, **kwargs)`` returns.
        """
        if _capture_inputs_active:
            # Non-tensor positional args (ints, None, tuples, ...) have no
            # ``.detach()``, so store them unchanged instead of raising
            # AttributeError mid-forward. ``detach()`` shares storage with
            # the input, so later in-place edits to it show up in the
            # captured tensor; clone() if that matters.
            _capture_inputs_dict[self.name] = [
                a.detach() if isinstance(a, torch.Tensor) else a for a in args
            ]
        return self.module(*args, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment