Generate attributions for transformers using Captum, processing batches instead of single instances for higher total throughput.
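The gist assumes a fine-tuned sequence classification model, its tokenizer, a tokenized dataset `ds`, and a `summarize_attributions` helper are already defined. Below is a minimal setup sketch, not part of the original gist: the IMDb dataset, the max-length padding, and the per-example L2 normalization inside `summarize_attributions` are assumptions chosen so the batched code runs end to end.

```python
import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import LayerIntegratedGradients

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The model should be fine-tuned for the classification task; bert-base-uncased
# is used here because the gist references model.bert.embeddings.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')

# Any tokenized classification dataset works; IMDb is only an example.
# padding='max_length' keeps every sequence the same length so the default
# DataLoader collation can stack the tensors.
ds = load_dataset('imdb', split='test')
ds = ds.map(lambda ex: tokenizer(ex['text'], padding='max_length', truncation=True), batched=True)
ds = ds.rename_column('label', 'labels')
ds.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

def summarize_attributions(attributions):
    # Collapse the embedding dimension to get one attribution per token,
    # then L2-normalize each example's attribution vector (batched variant
    # of the helper used in the Captum BERT tutorials).
    attributions = attributions.sum(dim=-1)
    attributions = attributions / torch.norm(attributions, dim=-1, keepdim=True)
    return attributions
```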
model.to(device)
model.eval()
model.zero_grad()

def forward_func(inputs, attention_mask=None):
    return model(inputs, attention_mask=attention_mask).logits

# Attribute with respect to the embedding layer of the model.
lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)

all_input_ids, all_ref_input_ids, all_attributions, all_pred_probs, all_pred_class, all_true_class, all_attr_class, all_attr_score, all_convergence_scores = ([] for i in range(9))

batch_size = 192  # ~11GB for 64 with bert-base-uncased
dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, pin_memory=True)

for batch in tqdm(dl, desc='Generating explanations'):
    # Build the reference (baseline) inputs: keep cls and sep tokens, set all others to pad token.
    batch['reference_input_ids'] = batch['input_ids'].clone()
    batch['reference_input_ids'][~((batch['reference_input_ids'] == tokenizer.cls_token_id) | (batch['reference_input_ids'] == tokenizer.sep_token_id))] = tokenizer.pad_token_id

    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    # Model predictions for the whole batch.
    logits = forward_func(input_ids, attention_mask=attention_mask)
    probs = logits.softmax(1).detach().cpu()
    pred_prob = probs.max(1).values
    pred_class = probs.argmax(1)

    # Store inputs and predictions (padding stripped from the input ids).
    all_input_ids.extend([x[x != tokenizer.pad_token_id].tolist() for x in input_ids.detach().cpu()])
    all_pred_probs.extend(pred_prob.tolist())
    all_pred_class.extend(pred_class.tolist())
    all_true_class.extend(batch['labels'].tolist())

    # Layer integrated gradients for the whole batch, targeting the predicted class.
    reference_input_ids = batch['reference_input_ids'].to(device)
    pred_class = pred_class.to(device)
    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=reference_input_ids,
        target=pred_class,
        return_convergence_delta=True,
        additional_forward_args=attention_mask,
        internal_batch_size=batch_size * 2,
    )

    # Reduce to one score per token and drop padding positions
    # (attention mask moved to cpu to match the cpu attributions).
    attributions_sum = summarize_attributions(attributions).detach().cpu()
    attributions_sum = [attr[torch.where(att_mask == 1)[0]] for attr, att_mask in zip(attributions_sum, attention_mask.detach().cpu())]

    # Truncate each reference input at its sep token.
    reference_input_ids = reference_input_ids.detach().cpu()
    ref_sep_token_indices = [torch.where(x == tokenizer.sep_token_id)[0] for x in reference_input_ids]
    reference_input_ids = [x[:i + 1] for x, i in zip(reference_input_ids, ref_sep_token_indices)]
    all_ref_input_ids.extend(reference_input_ids)

    all_attributions.extend(attributions_sum)
    all_attr_class.extend(pred_class.detach().cpu().tolist())
    all_attr_score.extend([attr.sum().item() for attr in attributions_sum])
    all_convergence_scores.extend(delta.detach().cpu().tolist())
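Not part of the gist, but a common next step: the collected lists line up with Captum's `VisualizationDataRecord`, so the batched results can be rendered in a notebook. A sketch, assuming the variables produced above:

```python
from captum.attr import visualization as viz

records = []
for i in range(len(all_attributions)):
    # all_input_ids and all_attributions both exclude padding, so their lengths match.
    tokens = tokenizer.convert_ids_to_tokens(all_input_ids[i])
    records.append(viz.VisualizationDataRecord(
        all_attributions[i],         # per-token attribution scores
        all_pred_probs[i],           # predicted probability
        all_pred_class[i],           # predicted class
        all_true_class[i],           # true label
        all_attr_class[i],           # class the attributions were computed for
        all_attr_score[i],           # total attribution score
        tokens,                      # token strings for display
        all_convergence_scores[i],   # convergence delta from integrated gradients
    ))

viz.visualize_text(records)
```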