Created July 3, 2022 13:16
-
-
Save monk1337/4fd93bec6082a1038b09e3a7b68e7f5b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import torch
from captum.attr import IntegratedGradients, Saliency
from torch_geometric.nn import Explainer, GCNConv, to_captum
def _scale_to_unit(attr):
    """Return ``attr`` divided by its maximum, so values land in [0, 1].

    Callers pass non-negative attributions (already ``abs()``-ed). When the
    tensor is empty or its maximum is not positive (e.g. every attribution
    is zero), ``attr`` is returned unchanged instead of producing NaNs from
    ``0 / 0`` or raising on ``max()`` of an empty tensor.
    """
    if attr.numel() == 0:
        return attr
    max_val = attr.max()
    if max_val > 0:
        return attr / max_val
    return attr


def explain_output(model, num_edges, edge_index, node_feat,
                   node_id=2,
                   device='cpu',
                   target=0):
    """Plot Integrated Gradients attributions for one node's model output.

    Runs Captum's ``IntegratedGradients`` three times on ``model`` (wrapped
    with ``to_captum``). Each result is scaled to [0, 1] and drawn with
    ``Explainer.visualize_subgraph`` followed by ``plt.show()``:

    1. over an all-ones edge mask (edge attributions only);
    2. over the node features (node attributions; edges coloured with the
       attributions from pass 1);
    3. over node features and the edge mask jointly.

    Args:
        model: graph model to explain; passed to ``to_captum`` and
            ``Explainer``. Presumably called as ``model(x, edge_index)``;
            TODO confirm against the caller's model.
        num_edges: length of the edge mask. Presumably equal to
            ``edge_index.size(1)``; verify against the caller.
        edge_index: graph connectivity, forwarded to the model and used by
            the visualizer.
        node_feat: node feature tensor. Assumed ``(num_nodes, num_features)``,
            since per-feature attributions are summed over ``dim=1``.
        node_id: index of the node whose output is explained
            (``output_idx`` for ``to_captum``, and the plotted center node).
        device: device on which the edge mask is allocated.
        target: output index (e.g. class) to attribute. Defaults to ``0``,
            matching the previously hard-coded value.

    Returns:
        None. The function only shows matplotlib figures.
    """
    # Edge explainability
    # ===================
    captum_model = to_captum(model, mask_type='edge', output_idx=node_id)
    # All-ones mask: the model's unmasked behaviour is the attribution input.
    edge_mask = torch.ones(num_edges, requires_grad=True, device=device)
    ig = IntegratedGradients(captum_model)
    # unsqueeze(0) adds the leading batch dimension Captum attributes over.
    ig_attr_edge = ig.attribute(
        edge_mask.unsqueeze(0), target=target,
        additional_forward_args=(node_feat, edge_index),
        internal_batch_size=1)
    # Scale absolute attributions to [0, 1].
    ig_attr_edge = _scale_to_unit(ig_attr_edge.squeeze(0).abs())

    explainer = Explainer(model)
    explainer.visualize_subgraph(node_id, edge_index, ig_attr_edge)
    plt.show()

    # Node explainability
    # ===================
    captum_model = to_captum(model, mask_type='node', output_idx=node_id)
    ig = IntegratedGradients(captum_model)
    ig_attr_node = ig.attribute(
        node_feat.unsqueeze(0), target=target,
        additional_forward_args=(edge_index,),
        internal_batch_size=1)
    # Collapse per-feature attributions into one score per node, then scale.
    ig_attr_node = _scale_to_unit(ig_attr_node.squeeze(0).abs().sum(dim=1))
    # Edges are coloured with the edge-only attributions from the first pass.
    explainer.visualize_subgraph(node_id, edge_index, ig_attr_edge,
                                 node_alpha=ig_attr_node)
    plt.show()

    # Node and edge explainability
    # ============================
    captum_model = to_captum(model, mask_type='node_and_edge',
                             output_idx=node_id)
    ig = IntegratedGradients(captum_model)
    ig_attr_node, ig_attr_edge = ig.attribute(
        (node_feat.unsqueeze(0), edge_mask.unsqueeze(0)), target=target,
        additional_forward_args=(edge_index,), internal_batch_size=1)
    ig_attr_node = _scale_to_unit(ig_attr_node.squeeze(0).abs().sum(dim=1))
    ig_attr_edge = _scale_to_unit(ig_attr_edge.squeeze(0).abs())
    explainer.visualize_subgraph(node_id, edge_index, ig_attr_edge,
                                 node_alpha=ig_attr_node)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.