Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active January 9, 2021 10:43
Show Gist options
  • Select an option

  • Save sadimanna/6e70ff56c58dc8ddb2e66763f59326a5 to your computer and use it in GitHub Desktop.

Select an option

Save sadimanna/6e70ff56c58dc8ddb2e66763f59326a5 to your computer and use it in GitHub Desktop.
class NewModel(nn.Module):
    """Wrap a pretrained ResNet-50 and capture intermediate outputs via forward hooks.

    Args:
        output_layers: Collection of integer positions; the child module of the
            ResNet at each listed position (in ``named_children()`` order) gets a
            forward hook that records its output.
        *args: Forwarded unchanged to ``nn.Module.__init__``.

    ``forward`` returns ``(final_output, features)``, where ``features`` maps
    each hooked child's name to the output it produced during that call.
    """

    def __init__(self, output_layers, *args):
        super().__init__(*args)
        self.output_layers = output_layers
        # Filled by the forward hooks; replaced with a fresh dict on every forward().
        self.selected_out = OrderedDict()
        # PRETRAINED MODEL
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=`; kept as-is for compatibility with older torchvision.
        self.pretrained = models.resnet50(pretrained=True)
        # Handles returned by register_forward_hook, kept so hooks can be removed.
        self.fhooks = []
        for i, (name, child) in enumerate(self.pretrained.named_children()):
            if i in self.output_layers:
                self.fhooks.append(child.register_forward_hook(self.forward_hook(name)))

    def forward_hook(self, layer_name):
        """Return a forward hook that stores a module's output under ``layer_name``."""
        def hook(_module, _inputs, output):
            # Look up self.selected_out at call time so the hook writes into
            # whichever dict the current forward() pass installed.
            self.selected_out[layer_name] = output
        return hook

    def forward(self, x):
        """Run the wrapped model on ``x``.

        Returns:
            A tuple ``(out, features)``. ``out`` is the wrapped model's output.
            ``features`` is a new OrderedDict per call, mapping each hooked
            layer name to its output. Results from earlier calls are therefore
            never overwritten.
        """
        self.selected_out = OrderedDict()
        out = self.pretrained(x)
        return out, self.selected_out

    def remove_hooks(self):
        """Detach every forward hook registered by this wrapper."""
        for handle in self.fhooks:
            handle.remove()
        self.fhooks.clear()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment