

@isears
Created May 2, 2024 10:25
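```python
import pickle

import torch
from captum.attr import IntegratedGradients

# Load the trained model (unpickling requires the model's class to be importable).
with open("./tst_save.pkl", "rb") as f:
    trained_model = pickle.load(f)
trained_model.eval()  # inference mode: disables dropout, freezes batch-norm stats

# Load the test inputs and enable gradient tracking on them.
X = torch.load("./X_test.pt")
X.requires_grad = True

# Integrated Gradients w.r.t. output index 0 (Captum defaults: zero baseline, 50 steps).
attributor = IntegratedGradients(trained_model)
attributions = attributor.attribute(X, target=0)
assert attributions.shape == X.shape  # one attribution per input feature

torch.save(attributions, "./attributions.pt")
```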
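To make clear what `attribute` computes, here is a minimal dependency-free sketch of Integrated Gradients, assuming a differentiable function whose gradient is known analytically. For each feature i it approximates (x_i - x'_i) times the integral over alpha in [0, 1] of dF/dx_i(x' + alpha(x - x')), using a midpoint Riemann sum. Captum's default integration method is Gauss-Legendre, so this is an illustration of the idea rather than its implementation. The toy function `f`, its gradient `grad_f`, and the helper `integrated_gradients` are hypothetical.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def integrated_gradients(
    grad_f: Callable[[Sequence[float]], Vector],
    x: Sequence[float],
    baseline: Sequence[float],
    n_steps: int = 50,
) -> Vector:
    """Approximate IG_i = (x_i - b_i) * integral_0^1 dF/dx_i(b + a(x - b)) da
    with a midpoint Riemann sum over n_steps points on the straight-line path."""
    diff = [xi - bi for xi, bi in zip(x, baseline)]
    totals = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [bi + alpha * di for bi, di in zip(baseline, diff)]
        for i, g in enumerate(grad_f(point)):
            totals[i] += g
    # Average path gradient, scaled by the input-minus-baseline difference.
    return [di * t / n_steps for di, t in zip(diff, totals)]


# Hypothetical toy "model": f(x) = x0 * x1 + x2**2, with its analytic gradient.
def f(x: Sequence[float]) -> float:
    return x[0] * x[1] + x[2] ** 2


def grad_f(x: Sequence[float]) -> Vector:
    return [x[1], x[0], 2.0 * x[2]]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]  # all-zeros, matching Captum's default baseline
attr = integrated_gradients(grad_f, x, baseline)
# Completeness check: attributions should sum to f(x) - f(baseline).
delta = sum(attr) - (f(x) - f(baseline))
print(attr, delta)
```

Here the attributions come out to about [1, 1, 9], summing to f(x) - f(baseline) = 11, so `delta` is essentially zero. In Captum the analogous check is `attribute(..., return_convergence_delta=True)`.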