Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active July 31, 2021 10:18
Show Gist options
  • Save sadimanna/df7fae2b7c14d62b40911e8dc60300af to your computer and use it in GitHub Desktop.
class GradCamModel(nn.Module):
    """ResNet-50 wrapper that exposes ``layer4`` activations and their gradients.

    Judging by the class name, this is intended for Grad-CAM style
    attribution. A forward hook on ``self.pretrained.layer4`` stores that
    layer's output in ``self.selected_out``. When that output takes part in
    autograd, a tensor hook also stores the gradient flowing back into it in
    ``self.gradients`` during ``backward()``.

    Attributes:
        gradients: Gradient w.r.t. the most recently hooked ``layer4`` output,
            or ``None`` until a backward pass has reached it.
        tensorhook: Handles returned by ``Tensor.register_hook``, one per
            forward pass in which the activation required grad.
        layerhook: Handles returned by ``register_forward_hook``.
        selected_out: ``layer4`` output from the most recent forward pass.
        pretrained: The wrapped ``torchvision`` ResNet-50.
    """

    def __init__(self):
        super().__init__()
        self.gradients = None
        self.tensorhook = []
        self.layerhook = []
        self.selected_out = None
        # PRETRAINED MODEL
        # NOTE(review): newer torchvision releases deprecate ``pretrained=True``
        # in favour of ``weights=...``; kept as-is for compatibility with the
        # torchvision version this was written against.
        self.pretrained = models.resnet50(pretrained=True)
        self.layerhook.append(
            self.pretrained.layer4.register_forward_hook(self.forward_hook())
        )
        # Make sure every backbone parameter participates in autograd.
        for p in self.pretrained.parameters():
            p.requires_grad = True

    def activations_hook(self, grad):
        """Store ``grad``, the gradient w.r.t. the hooked activation."""
        self.gradients = grad

    def get_act_grads(self):
        """Return the gradient captured by the last backward pass, or ``None``."""
        return self.gradients

    def forward_hook(self):
        """Build the forward hook attached to ``self.pretrained.layer4``.

        Returns:
            A ``hook(module, inp, out)`` callable. It records ``out`` in
            ``self.selected_out`` and, if ``out`` requires grad, registers
            ``activations_hook`` on it.
        """

        def hook(module, inp, out):
            self.selected_out = out
            # Tensor.register_hook raises RuntimeError on a tensor that does
            # not require grad (e.g. under torch.no_grad() for inference), so
            # only attach the gradient hook when autograd is tracking ``out``.
            if out.requires_grad:
                self.tensorhook.append(out.register_hook(self.activations_hook))

        return hook

    def forward(self, x):
        """Run ``x`` through the wrapped model.

        Returns:
            A tuple ``(out, activations)``. ``out`` is the wrapped model's
            output. ``activations`` is ``self.selected_out``, i.e. the
            ``layer4`` output recorded by the forward hook during this call.
        """
        out = self.pretrained(x)
        return out, self.selected_out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment