Created August 3, 2018 00:06
PyTorch code to save activations for specific layers over an entire dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tmodels
from functools import partial
import collections

# dummy data: 10 batches of images with batch size 16
dataset = [torch.rand(16, 3, 224, 224).cuda() for _ in range(10)]

# network: a resnet50
net = tmodels.resnet50(pretrained=True).cuda()

# a dictionary that keeps saving the activations as they come
activations = collections.defaultdict(list)
def save_activation(name, mod, inp, out):
    activations[name].append(out.cpu())

# Registering hooks for all the Conv2d layers
# Note: Hooks are called EVERY TIME the module performs a forward pass. For modules that are
# called repeatedly at different stages of the forward pass (like ReLUs), this will save different
# activations. Editing the forward pass code to save activations is the way to go for these cases.
for name, m in net.named_modules():
    if type(m) == nn.Conv2d:
        # partial to assign the layer name to each hook
        m.register_forward_hook(partial(save_activation, name))

# forward pass through the full dataset
for batch in dataset:
    out = net(batch)

# concatenate all the outputs we saved to get the activations for each layer for the whole dataset
activations = {name: torch.cat(outputs, 0) for name, outputs in activations.items()}

# just print out the sizes of the saved activations as a sanity check
for k, v in activations.items():
    print(k, v.size())
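
To illustrate the caveat in the comments above: a forward hook fires every time its module is called, so if one module instance is reused at several points in forward() (a shared nn.ReLU, say), its list in the activations dictionary collects several tensors per batch rather than one per layer. A minimal sketch, with TwoStageNet as a made-up toy module:

import torch
import torch.nn as nn
from functools import partial
import collections

class TwoStageNet(nn.Module):
    # hypothetical module that reuses one nn.ReLU instance twice
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)
        self.relu = nn.ReLU()          # single module, called twice in forward()

    def forward(self, x):
        x = self.relu(self.fc1(x))     # hook fires here...
        return self.relu(self.fc2(x))  # ...and fires again here

acts = collections.defaultdict(list)
def save_activation(name, mod, inp, out):
    acts[name].append(out.detach().cpu())

net = TwoStageNet()
net.relu.register_forward_hook(partial(save_activation, 'relu'))

net(torch.rand(4, 8))
print(len(acts['relu']))  # 2 entries after a single forward pass, not 1

In that case the saved tensors no longer map one-to-one to layers, which is why editing the forward pass (or giving each call site its own module instance) is the cleaner route. Note also that register_forward_hook returns a handle whose remove() method detaches the hook again if you want to restore the network afterwards.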
Great, consider writing a blog post as @jzm0144 suggested, and include some examples of what this part describes:

> For modules that are called repeatedly at different stages of the forward pass (like ReLUs), this will save different activations.
Nice work. Perhaps you should consider writing a small blog post on how you use the hook method; before landing on your code, I spent a lot of time looking for a quick and straightforward approach.