Integrated gradients wrapper
""" | |
USAGE | |
model = build_model() | |
attributor = Attributor(model, target_class=1, tokenizer=tokenizer) | |
... | |
# viz = interactive vizualization that you can dump into a file and look at in a web browser | |
# t2a = map of token to its attribution score | |
viz, t2a, attrs, y_prob, y_hat = attributor.attr_and_visualize( | |
batch['input_ids'], batch['labels']) | |
with open('vizualization.html', 'w') as f: | |
f.write('\n'.join(viz)) | |
""" | |
from collections import defaultdict

import torch
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

# Run on GPU when one is available.
CUDA = torch.cuda.is_available()
class Attributor:
    """Wraps a binary DistilBERT classifier with Captum's layer integrated gradients."""

    def __init__(self, model, target_class, tokenizer):
        """ TODO generalize to multiclass """
        self.model = model
        self.target_class = target_class
        self.tokenizer = tokenizer
        self.fwd_fn = self.build_forward_fn(target_class)
        # Attribute with respect to the model's embedding layer.
        self.lig = LayerIntegratedGradients(self.fwd_fn, self.model.distilbert.embeddings)

    def attribute(self, input_ids):
        # Baseline keeps [CLS] (101) and [SEP] (102) and replaces every other token with [PAD] (0).
        ref_ids = [[x if x in [101, 102] else 0 for x in input_ids[0]]]
        attribution, delta = self.lig.attribute(
            inputs=torch.tensor(input_ids).cuda() if CUDA else torch.tensor(input_ids),
            baselines=torch.tensor(ref_ids).cuda() if CUDA else torch.tensor(ref_ids),
            n_steps=25,
            internal_batch_size=5,
            return_convergence_delta=True)
        attribution_sum = self.summarize(attribution)
        return attribution_sum, delta

    def attr_and_visualize(self, input_ids, label):
        attr_sum, delta = self.attribute(input_ids)
        y_prob = self.fwd_fn(input_ids)
        pred_class = 1 if y_prob.data[0] > 0.5 else 0
        if CUDA:
            input_ids = input_ids.cpu().numpy()[0]
            label = label.cpu().item()
            attr_sum = attr_sum.cpu().numpy()
            y_prob = y_prob.cpu().item()
        else:
            input_ids = input_ids.numpy()[0]
            label = label.item()
            attr_sum = attr_sum.numpy()
            y_prob = y_prob.item()
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        record = viz.VisualizationDataRecord(
            attr_sum,
            y_prob,
            pred_class,
            label,
            self.target_class,
            attr_sum.sum(),
            tokens,
            delta)
        # Collect every attribution score observed for each token.
        tok2attr = defaultdict(list)
        for tok, attr in zip(tokens, attr_sum):
            tok2attr[tok].append(attr)
        html = viz.visualize_text([record])
        return html.data, tok2attr, attr_sum, y_prob, pred_class

    def build_forward_fn(self, label_dim):
        """ Return a forward function that outputs the probability of class `label_dim`. """
        def custom_forward(inputs):
            preds = self.model(inputs)[0]
            return torch.softmax(preds, dim=1)[:, label_dim]
        return custom_forward

    def summarize(self, attributions):
        """ Sum across each embedding dim and normalize. """
        attributions = attributions.sum(dim=-1).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        return attributions
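
# ---------------------------------------------------------------------------
# Example (not part of the original gist): a minimal sketch of wiring the
# wrapper up to a Hugging Face DistilBERT classifier. The gist's build_model()
# is not shown, so the checkpoint name ('distilbert-base-uncased'), the example
# sentence, and the label below are assumptions; any DistilBERT sequence
# classifier that exposes model.distilbert.embeddings should work.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased', num_labels=2)
    model = model.cuda() if CUDA else model
    model.eval()

    attributor = Attributor(model, target_class=1, tokenizer=tokenizer)

    encoding = tokenizer('an example sentence to attribute', return_tensors='pt')
    input_ids = encoding['input_ids'].cuda() if CUDA else encoding['input_ids']
    label = torch.tensor(1)  # assumed gold label for the example sentence

    html, tok2attr, attrs, y_prob, y_hat = attributor.attr_and_visualize(input_ids, label)
    with open('visualization.html', 'w') as f:
        f.write(html)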