@gasiort
Last active February 28, 2018 00:18
Guided backprop in PyTorch
import torch
from torch.nn import ReLU


class GuidedBackprop(torch.nn.Module):
    def __init__(self, model, n_classes):
        super().__init__()
        self.model = model
        self.n_classes = n_classes
        self.gradients = None
        self.model.eval()
        self.update_relus()
        self.hook_first_layer()

    def hook_first_layer(self):
        """
        Hook that stores the last gradient (w.r.t. the first layer's input,
        i.e. the image) in an attribute.
        """
        def hook_function(module, grad_in, grad_out):
            self.gradients = grad_in[0]
        first_layer = self.model.features[0]
        first_layer.register_backward_hook(hook_function)

    def update_relus(self):
        """
        For every activation, register a hook that acts like a ReLU on the
        gradient during backpropagation - that is the whole idea of
        Guided Backprop.
        """
        def relu_hook_function(module, grad_in, grad_out):
            if isinstance(module, ReLU):
                # zero out negative gradients; positions where the forward
                # activation was negative are already zero in grad_in
                return (torch.clamp(grad_in[0], min=0.0),)
        for pos, module in self.model.features._modules.items():
            if isinstance(module, ReLU):
                module.register_backward_hook(relu_hook_function)

    def forward(self, img, img_class):
        # forward pass through the model
        model_output = self.model(img)
        # build the one-hot target for this image's class
        one_hot = torch.zeros(1, self.n_classes)
        one_hot[0][img_class] = 1
        # backward pass
        self.model.zero_grad()
        model_output.backward(gradient=one_hot)
        return self.gradients.data.numpy()[0]
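
A minimal usage sketch, under the assumption of a torchvision VGG-style network (any model exposing a `features` block of Conv2d/ReLU layers works); the names `img`, `target_class`, and `saliency` are illustrative, not part of the class above:

import torch
from torchvision import models

# assumed example model with a `features` submodule, as the class requires
model = models.vgg16(pretrained=True)
gbp = GuidedBackprop(model, n_classes=1000)

# a preprocessed 1x3x224x224 input; it must require grad so that the
# backward hook on the first layer receives a gradient w.r.t. the image
img = torch.randn(1, 3, 224, 224, requires_grad=True)
target_class = 243  # arbitrary ImageNet class index, for illustration

saliency = gbp(img, target_class)  # numpy array of shape (3, 224, 224)

The returned array is the guided gradient of the target class score with respect to the input image; visualizing its magnitude per pixel gives the usual guided-backprop saliency map.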