Trying Captum interpretation on a pretrained sentiment classifier
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"orig_nbformat": 2,
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"npconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": 3,
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"colab": {
"name": "Interpretation.ipynb",
"provenance": [],
"collapsed_sections": []
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "UFESEuEgbUDD",
"colab_type": "text"
},
"source": [
"# Interpretation of BertForSequenceClassification in captum\n", | |
"\n", | |
"In this notebook we use Captum to interpret a BERT sentiment classifier finetuned on the imdb dataset https://huggingface.co/lvwerra/bert-imdb " | |
]
},
{
"cell_type": "code",
"metadata": {
"id": "EJ51JAxHbghp",
"colab_type": "code",
"colab": {}
},
"source": [
"# Install dependencies\n",
"!pip install transformers\n",
"!pip install captum"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CS9Kaz8ubUDG",
"colab_type": "code",
"colab": {}
},
"source": [
"from transformers import BertTokenizer, BertForSequenceClassification, BertConfig\n",
"from captum.attr import visualization as viz\n",
"from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients\n",
"from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer\n",
"import torch\n",
"import matplotlib.pyplot as plt"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "P1yl1gdvbUDS",
"colab_type": "code",
"colab": {}
},
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3U5XDt1Gb73t",
"colab_type": "code",
"colab": {}
},
"source": [
"# Get model and config files from https://huggingface.co/lvwerra/bert-imdb\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/special_tokens_map.json\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/tokenizer_config.json\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/training_args.bin\n",
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/vocab.txt"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "X-nyyq_tbUDa",
"colab_type": "code",
"colab": {}
},
"source": [
"# load model\n",
"model = BertForSequenceClassification.from_pretrained('./model')\n",
"model.to(device)\n",
"model.eval()\n",
"model.zero_grad()\n",
"\n",
"# load tokenizer\n",
"tokenizer = BertTokenizer.from_pretrained('./model')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JUMsvUOTbUDi",
"colab_type": "code",
"colab": {}
},
"source": [
"def predict(inputs):\n", | |
" return model(inputs)[0]" | |
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SIbauwGbbUDo",
"colab_type": "code",
"colab": {}
},
"source": [
"ref_token_id = tokenizer.pad_token_id # A token used for generating token reference\n", | |
"sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.\n", | |
"cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence" | |
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "mcnTCNUFbUD1",
"colab_type": "code",
"colab": {}
},
"source": [
"def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):\n",
"\n",
"    text_ids = tokenizer.encode(text, add_special_tokens=False)\n",
"    # construct input token ids\n",
"    input_ids = [cls_token_id] + text_ids + [sep_token_id]\n",
"    # construct reference token ids\n",
"    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]\n",
"\n",
"    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)\n",
"\n",
"def construct_input_ref_token_type_pair(input_ids, sep_ind=0):\n",
"    seq_len = input_ids.size(1)\n",
"    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)\n",
" ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1\n", | |
" return token_type_ids, ref_token_type_ids\n", | |
"\n", | |
"def construct_input_ref_pos_id_pair(input_ids):\n", | |
" seq_length = input_ids.size(1)\n", | |
" position_ids = torch.arange(seq_length, dtype=torch.long, device=device)\n", | |
" # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`\n", | |
" ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)\n", | |
"\n", | |
" position_ids = position_ids.unsqueeze(0).expand_as(input_ids)\n", | |
" ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)\n", | |
" return position_ids, ref_position_ids\n", | |
" \n", | |
"def construct_attention_mask(input_ids):\n", | |
" return torch.ones_like(input_ids)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{
"cell_type": "code",
"metadata": {
"id": "vhasPia4bUD8",
"colab_type": "code",
"colab": {}
},
"source": [
"def custom_forward(inputs):\n", | |
" preds = predict(inputs)\n", | |
" return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)" | |
],
"execution_count": 0,
"outputs": []
},
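{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `custom_forward` returns the softmax probability of class 0 (the negative class), so the attributions below are computed with respect to the negative-class score. A minimal sketch of a variant targeting the positive class (index 1) follows; this hypothetical `custom_forward_positive` is not used in the rest of the notebook."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical variant (not used below): attribute with respect to the\n",
"# positive class (index 1) instead of the negative class (index 0)\n",
"def custom_forward_positive(inputs):\n",
"    preds = predict(inputs)\n",
"    return torch.softmax(preds, dim=1)[:, 1]"
],
"execution_count": 0,
"outputs": []
},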
{
"cell_type": "code",
"metadata": {
"id": "pGwkb1vAbUEA",
"colab_type": "code",
"colab": {}
},
"source": [
"lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EQlVDaISbUEF",
"colab_type": "code",
"colab": {}
},
"source": [
"# One can test a couple of examples and check that the sentiment classifier is behaving\n", | |
"text = \"The movie was one of those amazing movies you can't forget.\"\n", | |
"#text = \"The movie was one of those crappy movies you can't forget.\"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "BtoFctjVbUEM", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)\n", | |
"token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)\n", | |
"position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)\n", | |
"attention_mask = construct_attention_mask(input_ids)\n", | |
"\n", | |
"indices = input_ids[0].detach().tolist()\n", | |
"all_tokens = tokenizer.convert_ids_to_tokens(indices)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
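{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (an added step, not part of the original flow), the token list and tensor shapes built above can be inspected directly."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Inspect the tokenization and the input/reference tensor shapes\n",
"print(all_tokens)\n",
"print(input_ids.shape, ref_input_ids.shape)"
],
"execution_count": 0,
"outputs": []
},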
{
"cell_type": "code",
"metadata": {
"id": "T4vlqBBrbUEY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "1a54bbcf-955c-4043-ebf6-c4ff9997bbe4"
},
"source": [
"# Check predict output\n",
"predict(input_ids)"
],
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-3.3635, 4.0115]], device='cuda:0', grad_fn=<AddmmBackward>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wpNkwy6_bUEd",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "e4f6d662-5753-4baa-cb49-f9a3c35f21d9"
},
"source": [
"# Check output of custom_forward\n",
"custom_forward(input_ids)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([0.0006], device='cuda:0', grad_fn=<UnsqueezeBackward0>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "YAzBqQlpbUEk",
"colab_type": "code",
"colab": {}
},
"source": [
"attributions, delta = lig.attribute(inputs=input_ids,\n",
"                                    baselines=ref_input_ids,\n",
"                                    return_convergence_delta=True)"
],
"execution_count": 0,
"outputs": []
},
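{
"cell_type": "markdown",
"metadata": {},
"source": [
"The convergence delta returned by `lig.attribute` measures the completeness error of the integrated-gradients approximation. Below is a minimal added check, assuming a single example in the batch; if the delta is large, a higher `n_steps` in `lig.attribute` may tighten the approximation."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sanity check: the convergence delta should be close to zero\n",
"# (assumes a batch of one example; increase n_steps in lig.attribute if it is large)\n",
"print('Convergence delta:', delta.item())"
],
"execution_count": 0,
"outputs": []
},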
{
"cell_type": "code",
"metadata": {
"id": "dU8SRQFybUEo",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
},
"outputId": "bcadcde0-8ce7-4b0e-814c-6b14570f1e2a"
},
"source": [
"score = predict(input_ids)\n",
"\n",
"print('Sentence: ', text)\n",
"print('Sentiment: ' + str(torch.argmax(score[0]).cpu().numpy()) + \\\n",
"      ', Probability positive: ' + str(torch.softmax(score, dim=1)[0][1].cpu().detach().numpy()))"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"Sentence: The movie was one of those amazing movies you can't forget.\n",
"Sentiment: 1, Probability positive: 0.99937373\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Hq8R_ZYubUEu",
"colab_type": "code",
"colab": {}
},
"source": [
"def summarize_attributions(attributions):\n",
"    attributions = attributions.sum(dim=-1).squeeze(0)\n",
"    attributions = attributions / torch.norm(attributions)\n",
"    return attributions"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3q7xXwRrbUEx",
"colab_type": "code",
"colab": {}
},
"source": [
"attributions_sum = summarize_attributions(attributions)"
],
"execution_count": 0,
"outputs": []
},
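{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before rendering the HTML visualization, the normalized per-token attributions can also be listed as plain text (a minimal added sketch using the `all_tokens` and `attributions_sum` computed above)."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Print each token next to its summarized attribution score\n",
"for token, attr in zip(all_tokens, attributions_sum):\n",
"    print(f\"{token:>10s} {attr.item():+.3f}\")"
],
"execution_count": 0,
"outputs": []
},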
{
"cell_type": "code",
"metadata": {
"id": "0ZF0RmZ4bUE1",
"colab_type": "code",
"colab": {}
},
"source": [
"# storing couple samples in an array for visualization purposes\n", | |
"score_vis = viz.VisualizationDataRecord(attributions_sum,\n", | |
" torch.softmax(score, dim = 1)[0][0],\n", | |
" torch.argmax(torch.softmax(score, dim = 1)[0]),\n", | |
" 0,\n", | |
" text,\n", | |
" attributions_sum.sum(), \n", | |
" all_tokens,\n", | |
" delta)\n" | |
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-gAojuO6ody0",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 131
},
"outputId": "6875e039-c494-4b0f-ec78-ff192daf0918"
},
"source": [
"print('\\033[1m', 'Visualization For Score', '\\033[0m')\n",
"viz.visualize_text([score_vis])"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[1m Visualization For Score \u001b[0m\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table width: 100%><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>0</b></text></td><td><text style=\"padding-right:2em\"><b>1 (0.00)</b></text></td><td><text style=\"padding-right:2em\"><b>The movie was one of those amazing movies you can't forget.</b></text></td><td><text style=\"padding-right:2em\"><b>-0.72</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0; line-height:1.75\"><font color=\"black\"> The </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0; line-height:1.75\"><font color=\"black\"> was </font></mark><mark style=\"background-color: hsl(120, 75%, 71%); opacity:1.0; line-height:1.75\"><font color=\"black\"> one </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> of </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75\"><font color=\"black\"> those </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0; line-height:1.75\"><font color=\"black\"> amazing </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movies </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75\"><font color=\"black\"> you </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> can </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0; line-height:1.75\"><font color=\"black\"> ' </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> t </font></mark><mark style=\"background-color: hsl(120, 75%, 98%); opacity:1.0; line-height:1.75\"><font color=\"black\"> forget </font></mark><mark style=\"background-color: hsl(0, 75%, 71%); opacity:1.0; line-height:1.75\"><font color=\"black\"> . </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ItXD4N9FogZu",
"colab_type": "text"
},
"source": [
"The visualization is clearly meaningless! :(\n"
]
}
]
} |