Generate attributions for transformers using Captum, but with batches instead of per-instance processing, for higher total throughput.
import torch
from tqdm import tqdm
from captum.attr import LayerIntegratedGradients

# Assumes `model` (a BERT sequence classifier), `tokenizer`, `ds` (a dataset yielding
# input_ids, attention_mask and labels tensors) and `device` are defined elsewhere.
model.to(device)
model.eval()
model.zero_grad()

def forward_func(inputs, attention_mask=None):
    # Wrapper that Captum attributes through; returns the raw classification logits.
    return model(inputs, attention_mask=attention_mask).logits

# Attribute with respect to the embedding layer of the BERT encoder.
lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)
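# summarize_attributions() is used below but not defined in the gist. A minimal
# sketch following the helper from Captum's BERT tutorials, adapted here for
# batched input: collapse the embedding dimension and normalize each sequence.
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1)  # (batch, seq_len, hidden) -> (batch, seq_len)
    attributions = attributions / torch.norm(attributions, dim=-1, keepdim=True)
    return attributions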
# Accumulators for per-instance results (token ids, references, attributions, scores).
(all_input_ids, all_ref_input_ids, all_attributions, all_pred_probs, all_pred_class,
 all_true_class, all_attr_class, all_attr_score, all_convergence_scores) = ([] for _ in range(9))

batch_size = 192  # a batch size of 64 takes ~11GB with bert-base-uncased
dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, pin_memory=True)
for batch in tqdm(dl, desc='Generating explanations'):
    # Build the baseline ("reference") inputs: keep [CLS] and [SEP], set all other
    # tokens to [PAD].
    batch['reference_input_ids'] = batch['input_ids'].clone()
    batch['reference_input_ids'][~((batch['reference_input_ids'] == tokenizer.cls_token_id) |
                                   (batch['reference_input_ids'] == tokenizer.sep_token_id))] = tokenizer.pad_token_id

    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    # Plain forward pass to get the predicted class and its probability per instance.
    logits = forward_func(input_ids, attention_mask=attention_mask)
    probs = logits.softmax(1).detach().cpu()
    pred_prob = probs.max(1).values
    pred_class = probs.argmax(1)

    all_input_ids.extend([x[x != tokenizer.pad_token_id].tolist() for x in input_ids.detach().cpu()])
    all_pred_probs.extend(pred_prob.tolist())
    all_pred_class.extend(pred_class.tolist())
    all_true_class.extend(batch['labels'].tolist())

    # Attribute the predicted class for the whole batch at once.
    reference_input_ids = batch['reference_input_ids'].to(device)
    pred_class = pred_class.to(device)
    attributions, delta = lig.attribute(
        inputs=input_ids, baselines=reference_input_ids, target=pred_class,
        return_convergence_delta=True, additional_forward_args=attention_mask,
        internal_batch_size=batch_size * 2)

    # Collapse attributions to one score per token and drop padding positions.
    attributions_sum = summarize_attributions(attributions).detach().cpu()
    attention_mask = attention_mask.detach().cpu()  # move to cpu to index the cpu attributions
    attributions_sum = [attr[torch.where(att_mask == 1)[0]] for attr, att_mask in zip(attributions_sum, attention_mask)]

    # Truncate the reference ids at [SEP] (assumes a single [SEP] per sequence).
    reference_input_ids = reference_input_ids.detach().cpu()
    ref_sep_token_indices = [torch.where(x == tokenizer.sep_token_id)[0] for x in reference_input_ids]
    reference_input_ids = [x[:i + 1] for x, i in zip(reference_input_ids, ref_sep_token_indices)]

    all_ref_input_ids.extend(reference_input_ids)
    all_attributions.extend(attributions_sum)
    all_attr_class.extend(pred_class.detach().cpu().tolist())
    all_attr_score.extend([attr.sum().item() for attr in attributions_sum])
    all_convergence_scores.extend(delta.detach().cpu().tolist())
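# The accumulators line up with Captum's VisualizationDataRecord, so the results can
# be rendered as the usual token-highlight HTML. A minimal sketch (argument order as
# in Captum's BERT tutorials); visualize_text displays inline in a Jupyter notebook.
from captum.attr import visualization as viz

records = [
    viz.VisualizationDataRecord(
        word_attributions,                            # per-token attribution scores
        pred_prob,                                    # probability of the predicted class
        pred_class,                                   # predicted class
        true_class,                                   # ground-truth label
        attr_class,                                   # class the attributions target
        attr_score,                                   # sum of the attributions
        tokenizer.convert_ids_to_tokens(input_ids),   # the tokens themselves
        delta,                                        # convergence delta
    )
    for word_attributions, pred_prob, pred_class, true_class, attr_class, attr_score, input_ids, delta
    in zip(all_attributions, all_pred_probs, all_pred_class, all_true_class,
           all_attr_class, all_attr_score, all_input_ids, all_convergence_scores)
]
viz.visualize_text(records)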