Created May 29, 2020 19:40
Gist: ruslangrimov/cf01f5db03e185e8dcec157d8965fa90
Get outputs of each layer using hooks
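import torch


class SaveOutput:
    """Callable forward hook that collects the output of every hooked module."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # Called by PyTorch after each hooked module's forward pass.
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# `model` is assumed to be an existing torch.nn.Module.
save_output = SaveOutput()
hook_handles = []
for layer in model.modules():
    # Only Conv2d layers are hooked; change the type check to capture others.
    if isinstance(layer, torch.nn.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)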
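
The hooks registered above only fire during a forward pass, so nothing is collected until the model is actually called. The following is a minimal usage sketch, not part of the original gist: it repeats the `SaveOutput` class so it runs on its own, and the two-layer `nn.Sequential` model and the 32x32 random input are hypothetical stand-ins chosen for illustration.

```python
import torch
import torch.nn as nn


class SaveOutput:
    # Same collector as above, repeated so this snippet runs on its own.
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# Hypothetical toy model, used only for illustration.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)

save_output = SaveOutput()
hook_handles = [
    layer.register_forward_hook(save_output)
    for layer in model.modules()
    if isinstance(layer, nn.Conv2d)
]

# One forward pass fires each hook once, in execution order.
with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))

shapes = [tuple(out.shape) for out in save_output.outputs]
print(shapes)  # [(1, 8, 32, 32), (1, 16, 32, 32)]

# Drop stored tensors between passes so the list does not keep growing.
save_output.clear()

# Detach the hooks once the activations are no longer needed.
for handle in hook_handles:
    handle.remove()

# With the hooks removed, another forward pass records nothing.
with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))
hooks_removed = len(save_output.outputs) == 0
```

Calling `clear()` between passes matters because `outputs` accumulates across calls and keeps every stored tensor alive. Keeping the handles returned by `register_forward_hook` is what makes it possible to detach the hooks later with `handle.remove()`.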