
@maxidl
Created February 10, 2022 21:49
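
import torch


def get_simple_gradient_expl(model, images, targets, absolute=False):
    # Vanilla-gradient explanation: gradient of each sample's target-class score w.r.t. its input.
    # Assumes `images` is a leaf tensor (requires_grad can only be set on leaves) and that samples
    # are processed independently (e.g. model.eval(), no batch-norm statistics mixing the batch).
    images.requires_grad = True
    outputs = model(images)  # (N, num_classes) class scores
    outputs = outputs.gather(1, targets.unsqueeze(1))  # (N, 1): score of each sample's target class
    # unbind -> N one-element outputs; autograd seeds each with 1, equivalent to grad of outputs.sum()
    grad = torch.autograd.grad(torch.unbind(outputs), images, create_graph=True)[0]  # create_graph=True for second order derivative
    expl = grad.abs() if absolute else grad
    return expl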
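A minimal usage sketch, assuming a toy `nn.Linear` classifier and random inputs (both hypothetical, not part of the gist). A linear model is handy for sanity-checking because the gradient of class c's score with respect to the input is exactly the weight row `W[c]`. The second half shows what `create_graph=True` buys: the explanation stays differentiable with respect to the model parameters, so a (hypothetical) penalty on it can be backpropagated.

```python
import torch
import torch.nn as nn


def get_simple_gradient_expl(model, images, targets, absolute=False):
    images.requires_grad = True  # images must be a leaf tensor
    outputs = model(images).gather(1, targets.unsqueeze(1))  # (N, 1) target-class scores
    grad = torch.autograd.grad(torch.unbind(outputs), images, create_graph=True)[0]
    return grad.abs() if absolute else grad


torch.manual_seed(0)

# Hypothetical toy setup: 5 samples with 4 features each, 3 classes.
model = nn.Linear(4, 3)
images = torch.randn(5, 4)
targets = torch.tensor([0, 2, 1, 2, 0])

expl = get_simple_gradient_expl(model, images, targets)

# For a linear model, each sample's explanation equals the weight row of its target class.
print(torch.allclose(expl.detach(), model.weight.detach()[targets]))  # True

# Because of create_graph=True, expl depends on model.weight, so a penalty on the
# explanation (a hypothetical regularizer) produces parameter gradients.
penalty = expl.pow(2).sum()
penalty.backward()
print(model.weight.grad is not None)  # True
```

Dropping `create_graph=True` would make `expl` a constant with respect to the parameters, and `penalty.backward()` would then fail because nothing in the penalty requires grad.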