{
  "nbformat": 4,
  "nbformat_minor": 2,
  "metadata": {
    "language_info": {
      "name": "python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "version": "3.7.4-final"
    },
    "orig_nbformat": 2,
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": 3,
    "kernelspec": {
      "name": "python37464bitbaseconda591eac30377d4dc3af76304e9e0933b9",
      "display_name": "Python 3.7.4 64-bit ('base': conda)"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Interpretation of BertForSequenceClassification in captum"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": "I0304 15:15:00.931115 12960 file_utils.py:41] PyTorch version 1.4.0 available.\n"
        }
      ],
      "source": [
        "from transformers import BertTokenizer, BertForSequenceClassification, BertConfig\n",
        "\n",
        "from captum.attr import visualization as viz\n",
        "from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients\n",
        "from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer\n",
        "\n",
        "import torch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": "I0304 15:15:05.666696 12960 configuration_utils.py:254] loading configuration file ../model/config.json\nI0304 15:15:05.669722 12960 configuration_utils.py:292] Model config BertConfig {\n \"architectures\": [\n \"BertForSequenceClassification\"\n ],\n \"attention_probs_dropout_prob\": 0.1,\n \"bos_token_id\": null,\n \"do_sample\": false,\n \"eos_token_ids\": null,\n \"finetuning_task\": \"cola\",\n \"hidden_act\": \"gelu\",\n \"hidden_dropout_prob\": 0.1,\n \"hidden_size\": 768,\n \"id2label\": {\n \"0\": \"LABEL_0\",\n \"1\": \"LABEL_1\"\n },\n \"initializer_range\": 0.02,\n \"intermediate_size\": 3072,\n \"is_decoder\": false,\n \"label2id\": {\n \"LABEL_0\": 0,\n \"LABEL_1\": 1\n },\n \"layer_norm_eps\": 1e-12,\n \"length_penalty\": 1.0,\n \"max_length\": 20,\n \"max_position_embeddings\": 512,\n \"model_type\": \"bert\",\n \"num_attention_heads\": 12,\n \"num_beams\": 1,\n \"num_hidden_layers\": 12,\n \"num_labels\": 2,\n \"num_return_sequences\": 1,\n \"output_attentions\": false,\n \"output_hidden_states\": false,\n \"output_past\": true,\n \"pad_token_id\": null,\n \"pruned_heads\": {},\n \"repetition_penalty\": 1.0,\n \"temperature\": 1.0,\n \"top_k\": 50,\n \"top_p\": 1.0,\n \"torchscript\": false,\n \"type_vocab_size\": 2,\n \"use_bfloat16\": false,\n \"vocab_size\": 31116\n}\n\nI0304 15:15:05.673698 12960 modeling_utils.py:459] loading weights file ../model/pytorch_model.bin\nI0304 15:15:09.035054 12960 tokenization_utils.py:417] Model name '../model/' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). Assuming '../model/' is a path, a model identifier, or url to a directory containing tokenizer files.\nI0304 15:15:09.048053 12960 tokenization_utils.py:446] Didn't find file ../model/added_tokens.json. We won't load it.\nI0304 15:15:09.052233 12960 tokenization_utils.py:499] loading file ../model/vocab.txt\nI0304 15:15:09.056055 12960 tokenization_utils.py:499] loading file None\nI0304 15:15:09.058056 12960 tokenization_utils.py:499] loading file ../model/special_tokens_map.json\nI0304 15:15:09.060055 12960 tokenization_utils.py:499] loading file ../model/tokenizer_config.json\n"
        }
      ],
      "source": [
        "# load model\n",
        "model = BertForSequenceClassification.from_pretrained('../model/')\n",
        "model.to(device)\n",
        "model.eval()\n",
        "model.zero_grad()\n",
        "\n",
        "# load tokenizer\n",
        "tokenizer = BertTokenizer.from_pretrained('../model/')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "def predict(inputs):\n",
        "    # the model returns a tuple; its first element holds the logits\n",
        "    return model(inputs)[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [],
      "source": [
        "ref_token_id = tokenizer.pad_token_id # token used to build the reference (baseline) input\n",
        "sep_token_id = tokenizer.sep_token_id # separator token, appended to the end of the text\n",
        "cls_token_id = tokenizer.cls_token_id # classification token, prepended to the text"
      ]
    },
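    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sanity check, added here as a sketch (not part of the original run): decode the special token ids back to strings to confirm what the reference (baseline) input will be built from."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# added sketch, no recorded output: map the special ids back to tokens;\n",
        "# for a standard BERT vocabulary this prints ['[CLS]', '[PAD]', '[SEP]']\n",
        "print(tokenizer.convert_ids_to_tokens([cls_token_id, ref_token_id, sep_token_id]))"
      ]
    },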
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [],
      "source": [
        "def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):\n",
        "\n",
        "    text_ids = tokenizer.encode(text, add_special_tokens=False)\n",
        "    # construct input token ids\n",
        "    input_ids = [cls_token_id] + text_ids + [sep_token_id]\n",
        "    # construct reference token ids\n",
        "    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]\n",
        "\n",
        "    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)\n",
        "\n",
        "def construct_input_ref_token_type_pair(input_ids, sep_ind=0):\n",
        "    seq_len = input_ids.size(1)\n",
        "    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)\n",
        "    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)\n",
        "    return token_type_ids, ref_token_type_ids\n",
        "\n",
        "def construct_input_ref_pos_id_pair(input_ids):\n",
        "    seq_length = input_ids.size(1)\n",
        "    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)\n",
        "    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`\n",
        "    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)\n",
        "\n",
        "    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)\n",
        "    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)\n",
        "    return position_ids, ref_position_ids\n",
        "\n",
        "def construct_attention_mask(input_ids):\n",
        "    return torch.ones_like(input_ids)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [],
      "source": [
        "def custom_forward(inputs):\n",
        "    # Captum expects one scalar per example: return the probability of class 0\n",
        "    preds = predict(inputs)\n",
        "    return torch.softmax(preds, dim=1)[0][0].unsqueeze(-1)"
      ]
    },
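    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`custom_forward` above never sees the attention mask, even though one is constructed below. As a minimal sketch (the name `custom_forward_with_mask` is introduced here and is not part of the original notebook), here is a variant that accepts the mask; Captum can feed it via `additional_forward_args`, as shown after the attribution call below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# added sketch: a hypothetical mask-aware forward, not part of the original run\n",
        "def custom_forward_with_mask(inputs, attention_mask=None):\n",
        "    preds = model(inputs, attention_mask=attention_mask)[0]\n",
        "    # one scalar per example (probability of class 0), as Captum expects\n",
        "    return torch.softmax(preds, dim=1)[:, 0]"
      ]
    },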
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [],
      "source": [
        "lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "text = \"These tests do not work as expected.\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {},
      "outputs": [],
      "source": [
        "input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)\n",
        "token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)\n",
        "position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)\n",
        "attention_mask = construct_attention_mask(input_ids)\n",
        "\n",
        "indices = input_ids[0].detach().tolist()\n",
        "all_tokens = tokenizer.convert_ids_to_tokens(indices)"
      ]
    },
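    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A sketch added for clarity (no recorded output): print the tokens of the real input next to those of the reference input, which replaces every sentence token with `[PAD]` while keeping `[CLS]` and `[SEP]` in place."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# added sketch: compare the input tokens with the baseline tokens\n",
        "print(all_tokens)\n",
        "print(tokenizer.convert_ids_to_tokens(ref_input_ids[0].tolist()))"
      ]
    },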
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": "(tensor([[-3.4676, 3.5508]], grad_fn=<AddmmBackward>),)"
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "model(input_ids)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": "tensor([[-3.4676, 3.5508]], grad_fn=<AddmmBackward>)"
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "predict(input_ids)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": "tensor([0.0009], grad_fn=<UnsqueezeBackward0>)"
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "custom_forward(input_ids)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {},
      "outputs": [],
      "source": [
        "attributions, delta = lig.attribute(inputs=input_ids,\n",
        "                                    baselines=ref_input_ids,\n",
        "                                    return_convergence_delta=True)"
      ]
    },
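    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A hedged alternative to the call above, not executed here: use the mask-aware forward sketched earlier and pass the attention mask through `additional_forward_args`; `n_steps` (Captum's default is 50) controls how many points are used to approximate the path integral."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# added sketch: same attribution, but the attention mask reaches the model\n",
        "lig_masked = LayerIntegratedGradients(custom_forward_with_mask, model.bert.embeddings)\n",
        "attributions_masked, delta_masked = lig_masked.attribute(\n",
        "    inputs=input_ids,\n",
        "    baselines=ref_input_ids,\n",
        "    additional_forward_args=(attention_mask,),\n",
        "    n_steps=50,  # number of integration steps (the Captum default)\n",
        "    return_convergence_delta=True)"
      ]
    },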
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": "Question:These tests do not work as expected.\nPredicted Answer: 1, prob ungrammatical: 0.0008944187\n"
        }
      ],
      "source": [
        "score = predict(input_ids)\n",
        "\n",
        "print('Question: ', text)\n",
        "print('Predicted Answer: ' + str(torch.argmax(score[0]).numpy()) + ', prob ungrammatical: ' + str(torch.softmax(score, dim=1)[0][0].detach().numpy()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {},
      "outputs": [],
      "source": [
        "def summarize_attributions(attributions):\n",
        "    # sum over the embedding dimension to get one score per token,\n",
        "    # then normalize by the L2 norm of the resulting vector\n",
        "    attributions = attributions.sum(dim=-1).squeeze(0)\n",
        "    attributions = attributions / torch.norm(attributions)\n",
        "    return attributions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {},
      "outputs": [],
      "source": [
        "attributions_sum = summarize_attributions(attributions)"
      ]
    },
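    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before rendering the HTML view, a small added sketch (no recorded output): list the tokens ordered by the magnitude of their summarized attribution, which can be easier to scan than the color map."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# added sketch: tokens sorted by |attribution|, largest first\n",
        "for token, attr in sorted(zip(all_tokens, attributions_sum.tolist()),\n",
        "                          key=lambda pair: -abs(pair[1])):\n",
        "    print(f\"{token:>12s} {attr: .3f}\")"
      ]
    },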
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": "Visualization For Score\n"
        },
        {
          "data": {
            "text/html": "<table width: 100%><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>0</b></text></td><td><text style=\"padding-right:2em\"><b>1 (0.00)</b></text></td><td><text style=\"padding-right:2em\"><b>These tests do not work as expected.</b></text></td><td><text style=\"padding-right:2em\"><b>0.73</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75\"><font color=\"black\"> These </font></mark><mark style=\"background-color: hsl(120, 75%, 68%); opacity:1.0; line-height:1.75\"><font color=\"black\"> tests </font></mark><mark style=\"background-color: hsl(120, 75%, 90%); opacity:1.0; line-height:1.75\"><font color=\"black\"> do </font></mark><mark style=\"background-color: hsl(120, 75%, 86%); opacity:1.0; line-height:1.75\"><font color=\"black\"> not </font></mark><mark style=\"background-color: hsl(120, 75%, 80%); opacity:1.0; line-height:1.75\"><font color=\"black\"> work </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> as </font></mark><mark style=\"background-color: hsl(0, 75%, 88%); opacity:1.0; line-height:1.75\"><font color=\"black\"> expected </font></mark><mark style=\"background-color: hsl(0, 75%, 83%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>",
            "text/plain": "<IPython.core.display.HTML object>"
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# wrap the sample in a VisualizationDataRecord for rendering\n",
        "score_vis = viz.VisualizationDataRecord(\n",
        "    attributions_sum,\n",
        "    torch.softmax(score, dim=1)[0][0],\n",
        "    torch.argmax(torch.softmax(score, dim=1)[0]),\n",
        "    0,\n",
        "    text,\n",
        "    attributions_sum.sum(),\n",
        "    all_tokens,\n",
        "    delta)\n",
        "\n",
        "print('\\033[1m', 'Visualization For Score', '\\033[0m')\n",
        "viz.visualize_text([score_vis])"
      ]
    },
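    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a closing variant, a sketch that was not executed in the original run: `LayerIntegratedGradients` can also target a sub-layer, e.g. only the word embeddings, which leaves the position and token-type embeddings out of the attribution. (`LayerConductance`, imported at the top but unused here, gives yet another per-layer view; the Captum BERT tutorial shows that pattern.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# added sketch: attribute with respect to the word embeddings only\n",
        "lig_word = LayerIntegratedGradients(custom_forward, model.bert.embeddings.word_embeddings)\n",
        "attributions_word, delta_word = lig_word.attribute(inputs=input_ids,\n",
        "                                                   baselines=ref_input_ids,\n",
        "                                                   return_convergence_delta=True)\n",
        "summarize_attributions(attributions_word)"
      ]
    }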
  ]
} |