Skip to content

Instantly share code, notes, and snippets.

@monk1337
Created July 3, 2022 13:16
Show Gist options
  • Save monk1337/4fd93bec6082a1038b09e3a7b68e7f5b to your computer and use it in GitHub Desktop.
Save monk1337/4fd93bec6082a1038b09e3a7b68e7f5b to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import torch
from captum.attr import IntegratedGradients, Saliency
from torch_geometric.nn import Explainer, GCNConv, to_captum
def _scale_to_unit(attr):
    """Divide non-negative attributions by their maximum so they lie in [0, 1].

    An empty or all-zero tensor is returned unchanged instead of being turned
    into NaNs by a 0 / 0 division.
    """
    if attr.numel() == 0:
        return attr
    peak = attr.max()
    return attr / peak if peak > 0 else attr


def _edge_scores(attr):
    """Drop captum's leading batch dim and return |attr| scaled to [0, 1]."""
    return _scale_to_unit(attr.squeeze(0).abs())


def _node_scores(attr):
    """Drop the batch dim, sum |attr| over dim 1 (features), scale to [0, 1]."""
    return _scale_to_unit(attr.squeeze(0).abs().sum(dim=1))


def explain_output(model, num_edges, edge_index, node_feat,
                   node_id=2,
                   device='cpu',
                   target=0):
    """Plot Integrated-Gradients explanations for one node's prediction.

    Three explanations are computed with captum's ``IntegratedGradients`` on
    models wrapped by ``to_captum``, and each is drawn with
    ``Explainer.visualize_subgraph`` followed by ``plt.show()``:

    1. edge attributions only (``mask_type='edge'``);
    2. node-feature attributions (``mask_type='node'``), drawn together with
       the edge attributions from step 1;
    3. joint node and edge attributions (``mask_type='node_and_edge'``).

    All attributions are shown as absolute values scaled to [0, 1].

    Args:
        model: GNN passed to ``to_captum`` and ``Explainer``. It is presumably
            called as ``model(x, edge_index)``; confirm against the model.
        num_edges: Length of the all-ones edge mask. It presumably must equal
            ``edge_index.size(1)``.
        edge_index: Graph connectivity, forwarded to the model and to the
            subgraph visualization.
        node_feat: Node feature tensor. Node attributions are summed over
            dim 1, so it must be at least 2-D.
        node_id: Index of the output node being explained (``output_idx``)
            and the centre of the plotted subgraph.
        device: Device on which the edge mask is created.
        target: Index into the model's output row for ``node_id`` that is
            attributed, e.g. a class index. Defaults to 0, matching the
            previously hard-coded value.
    """
    edge_mask = torch.ones(num_edges, requires_grad=True, device=device)

    # Edge explainability
    # ===================
    captum_model = to_captum(model, mask_type='edge', output_idx=node_id)
    ig = IntegratedGradients(captum_model)
    ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0), target=target,
                                additional_forward_args=(node_feat, edge_index),
                                internal_batch_size=1)
    ig_attr_edge = _edge_scores(ig_attr_edge)

    explainer = Explainer(model)
    explainer.visualize_subgraph(node_id, edge_index, ig_attr_edge)
    plt.show()

    # Node explainability
    # ===================
    captum_model = to_captum(model, mask_type='node', output_idx=node_id)
    ig = IntegratedGradients(captum_model)
    # A trailing comma makes this a real 1-tuple rather than a bare tensor.
    ig_attr_node = ig.attribute(node_feat.unsqueeze(0), target=target,
                                additional_forward_args=(edge_index,),
                                internal_batch_size=1)
    ig_attr_node = _node_scores(ig_attr_node)

    # Edges are coloured with the edge-only attributions from the first pass.
    explainer.visualize_subgraph(node_id, edge_index, ig_attr_edge,
                                 node_alpha=ig_attr_node)
    plt.show()

    # Node and edge explainability
    # ============================
    captum_model = to_captum(model, mask_type='node_and_edge',
                             output_idx=node_id)
    ig = IntegratedGradients(captum_model)
    ig_attr_node, ig_attr_edge = ig.attribute(
        (node_feat.unsqueeze(0), edge_mask.unsqueeze(0)), target=target,
        additional_forward_args=(edge_index,), internal_batch_size=1)
    ig_attr_node = _node_scores(ig_attr_node)
    ig_attr_edge = _edge_scores(ig_attr_edge)

    explainer.visualize_subgraph(node_id, edge_index, ig_attr_edge,
                                 node_alpha=ig_attr_node)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment