Last active
August 6, 2022 03:07
-
-
Save davidefiocco/47137f6eb7e3351c9bac4580c2ccc9d4 to your computer and use it in GitHub Desktop.
@NarineK solution for BERT interpretation in https://github.com/pytorch/captum/issues/311#issuecomment-612460705
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"orig_nbformat": 2, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"npconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": 3, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"colab": { | |
"name": "Copy of Interpretation.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/davidefiocco/47137f6eb7e3351c9bac4580c2ccc9d4/copy-of-interpretation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "UFESEuEgbUDD", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Interpretation of BertForSequenceClassification in captum\n", | |
"\n", | |
"In this notebook we use Captum to interpret a BERT sentiment classifier finetuned on the imdb dataset https://huggingface.co/lvwerra/bert-imdb " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EJ51JAxHbghp", | |
"colab_type": "code", | |
"outputId": "856672a7-298c-4f26-a108-338c7027aa47", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 884 | |
} | |
}, | |
"source": [ | |
"# Install dependencies\n", | |
"!pip install transformers\n", | |
"!pip install captum" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting transformers\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a3/78/92cedda05552398352ed9784908b834ee32a0bd071a9b32de287327370b7/transformers-2.8.0-py3-none-any.whl (563kB)\n", | |
"\u001b[K |████████████████████████████████| 573kB 7.2MB/s \n", | |
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.2)\n", | |
"Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.12.34)\n", | |
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)\n", | |
"Collecting sacremoses\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a6/b4/7a41d630547a4afd58143597d5a49e07bfd4c42914d8335b2a5657efc14b/sacremoses-0.0.38.tar.gz (860kB)\n", | |
"\u001b[K |████████████████████████████████| 870kB 25.0MB/s \n", | |
"\u001b[?25hCollecting sentencepiece\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n", | |
"\u001b[K |████████████████████████████████| 1.0MB 29.2MB/s \n", | |
"\u001b[?25hCollecting tokenizers==0.5.2\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/d1/3f/73c881ea4723e43c1e9acf317cf407fab3a278daab3a69c98dcac511c04f/tokenizers-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.7MB)\n", | |
"\u001b[K |████████████████████████████████| 3.7MB 49.4MB/s \n", | |
"\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.38.0)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)\n", | |
"Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)\n", | |
"Requirement already satisfied: botocore<1.16.0,>=1.15.34 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.15.34)\n", | |
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.5)\n", | |
"Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.3.3)\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n", | |
"Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.1)\n", | |
"Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.1)\n", | |
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2019.11.28)\n", | |
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n", | |
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)\n", | |
"Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.34->boto3->transformers) (2.8.1)\n", | |
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.34->boto3->transformers) (0.15.2)\n", | |
"Building wheels for collected packages: sacremoses\n", | |
" Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for sacremoses: filename=sacremoses-0.0.38-cp36-none-any.whl size=884628 sha256=77a6b4aabd517c57817a0b235589cda947e4af206a8785014c0761c26c3a065e\n", | |
" Stored in directory: /root/.cache/pip/wheels/6d/ec/1a/21b8912e35e02741306f35f66c785f3afe94de754a0eaf1422\n", | |
"Successfully built sacremoses\n", | |
"Installing collected packages: sacremoses, sentencepiece, tokenizers, transformers\n", | |
"Successfully installed sacremoses-0.0.38 sentencepiece-0.1.85 tokenizers-0.5.2 transformers-2.8.0\n", | |
"Collecting captum\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/42/de/c018e206d463d9975444c28b0a4f103c9ca4b2faedf943df727e402a1a1e/captum-0.2.0-py3-none-any.whl (1.4MB)\n", | |
"\u001b[K |████████████████████████████████| 1.4MB 6.9MB/s \n", | |
"\u001b[?25hRequirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from captum) (3.2.1)\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from captum) (1.18.2)\n", | |
"Requirement already satisfied: torch>=1.2 in /usr/local/lib/python3.6/dist-packages (from captum) (1.4.0)\n", | |
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->captum) (1.2.0)\n", | |
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->captum) (2.8.1)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->captum) (0.10.0)\n", | |
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->captum) (2.4.6)\n", | |
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.1->matplotlib->captum) (1.12.0)\n", | |
"Installing collected packages: captum\n", | |
"Successfully installed captum-0.2.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "CS9Kaz8ubUDG", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from transformers import BertTokenizer, BertForSequenceClassification, BertConfig\n", | |
"from captum.attr import visualization as viz\n", | |
"from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients\n", | |
"from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer\n", | |
"import torch\n", | |
"import matplotlib.pyplot as plt" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "P1yl1gdvbUDS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3U5XDt1Gb73t", | |
"colab_type": "code", | |
"outputId": "12031166-fbed-47f0-e081-61c1f4f447d7", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
} | |
}, | |
"source": [ | |
"# Get model and config files from https://huggingface.co/lvwerra/bert-imdb\n", | |
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json\n", | |
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin\n", | |
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/special_tokens_map.json\n", | |
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/tokenizer_config.json\n", | |
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/training_args.bin\n", | |
"!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/vocab.txt" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"--2020-04-06 19:08:32-- https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json\n", | |
"Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.184.165\n", | |
"Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.184.165|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 1220 (1.2K) [application/json]\n", | |
"Saving to: ‘./model/config.json’\n", | |
"\n", | |
"\rconfig.json 0%[ ] 0 --.-KB/s \rconfig.json 100%[===================>] 1.19K --.-KB/s in 0s \n", | |
"\n", | |
"2020-04-06 19:08:32 (68.1 MB/s) - ‘./model/config.json’ saved [1220/1220]\n", | |
"\n", | |
"--2020-04-06 19:08:34-- https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin\n", | |
"Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.41.190\n", | |
"Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.41.190|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 1334420863 (1.2G) [application/octet-stream]\n", | |
"Saving to: ‘./model/pytorch_model.bin’\n", | |
"\n", | |
"pytorch_model.bin 100%[===================>] 1.24G 28.9MB/s in 44s \n", | |
"\n", | |
"2020-04-06 19:09:18 (29.1 MB/s) - ‘./model/pytorch_model.bin’ saved [1334420863/1334420863]\n", | |
"\n", | |
"--2020-04-06 19:09:20-- https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/special_tokens_map.json\n", | |
"Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.187.85\n", | |
"Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.187.85|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 112 [application/json]\n", | |
"Saving to: ‘./model/special_tokens_map.json’\n", | |
"\n", | |
"special_tokens_map. 100%[===================>] 112 --.-KB/s in 0s \n", | |
"\n", | |
"2020-04-06 19:09:20 (2.09 MB/s) - ‘./model/special_tokens_map.json’ saved [112/112]\n", | |
"\n", | |
"--2020-04-06 19:09:21-- https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/tokenizer_config.json\n", | |
"Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.29.14\n", | |
"Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.29.14|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 40 [application/json]\n", | |
"Saving to: ‘./model/tokenizer_config.json’\n", | |
"\n", | |
"tokenizer_config.js 100%[===================>] 40 --.-KB/s in 0s \n", | |
"\n", | |
"2020-04-06 19:09:21 (1.06 MB/s) - ‘./model/tokenizer_config.json’ saved [40/40]\n", | |
"\n", | |
"--2020-04-06 19:09:23-- https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/training_args.bin\n", | |
"Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.76.254\n", | |
"Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.76.254|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 1449 (1.4K) [application/octet-stream]\n", | |
"Saving to: ‘./model/training_args.bin’\n", | |
"\n", | |
"training_args.bin 100%[===================>] 1.42K --.-KB/s in 0s \n", | |
"\n", | |
"2020-04-06 19:09:23 (72.4 MB/s) - ‘./model/training_args.bin’ saved [1449/1449]\n", | |
"\n", | |
"--2020-04-06 19:09:25-- https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/vocab.txt\n", | |
"Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.80.75\n", | |
"Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.80.75|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 213450 (208K) [text/plain]\n", | |
"Saving to: ‘./model/vocab.txt’\n", | |
"\n", | |
"vocab.txt 100%[===================>] 208.45K 845KB/s in 0.2s \n", | |
"\n", | |
"2020-04-06 19:09:25 (845 KB/s) - ‘./model/vocab.txt’ saved [213450/213450]\n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "X-nyyq_tbUDa", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# load model\n", | |
"model = BertForSequenceClassification.from_pretrained('./model')\n", | |
"model.to(device)\n", | |
"model.eval()\n", | |
"model.zero_grad()\n", | |
"\n", | |
"# load tokenizer\n", | |
"tokenizer = BertTokenizer.from_pretrained('./model')" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JUMsvUOTbUDi", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def predict(inputs):\n", | |
" return model(inputs)[0]" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "SIbauwGbbUDo", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"ref_token_id = tokenizer.pad_token_id # A token used for generating token reference\n", | |
"sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.\n", | |
"cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "mcnTCNUFbUD1", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):\n", | |
"\n", | |
" text_ids = tokenizer.encode(text, add_special_tokens=False)\n", | |
" # construct input token ids\n", | |
" input_ids = [cls_token_id] + text_ids + [sep_token_id]\n", | |
" # construct reference token ids \n", | |
" ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]\n", | |
"\n", | |
" return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)\n", | |
"\n", | |
"def construct_input_ref_token_type_pair(input_ids, sep_ind=0):\n", | |
" seq_len = input_ids.size(1)\n", | |
" token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)\n", | |
" ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1\n", | |
" return token_type_ids, ref_token_type_ids\n", | |
"\n", | |
"def construct_input_ref_pos_id_pair(input_ids):\n", | |
" seq_length = input_ids.size(1)\n", | |
" position_ids = torch.arange(seq_length, dtype=torch.long, device=device)\n", | |
" # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`\n", | |
" ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)\n", | |
"\n", | |
" position_ids = position_ids.unsqueeze(0).expand_as(input_ids)\n", | |
" ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)\n", | |
" return position_ids, ref_position_ids\n", | |
" \n", | |
"def construct_attention_mask(input_ids):\n", | |
" return torch.ones_like(input_ids)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "vhasPia4bUD8", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def custom_forward(inputs):\n", | |
" preds = predict(inputs)\n", | |
" return torch.softmax(preds, dim = 1)[0][1].unsqueeze(-1)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pGwkb1vAbUEA", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EQlVDaISbUEF", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# One can test a couple of examples and check that the sentiment classifier is behaving\n", | |
"text = \"The movie was one of those amazing movies\"#\"The movie was one of those amazing movies you can not forget\"\n", | |
"#text = \"The movie was one of those crappy movies you can't forget.\"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "BtoFctjVbUEM", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)\n", | |
"token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)\n", | |
"position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)\n", | |
"attention_mask = construct_attention_mask(input_ids)\n", | |
"\n", | |
"indices = input_ids[0].detach().tolist()\n", | |
"all_tokens = tokenizer.convert_ids_to_tokens(indices)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "T4vlqBBrbUEY", | |
"colab_type": "code", | |
"outputId": "f44b69f0-06f6-4202-aa16-99b24c1dc8b0", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"# Check predict output\n", | |
"predict(input_ids)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[-3.1333, 3.6520]], device='cuda:0', grad_fn=<AddmmBackward>)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 26 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gcuf6-v1v952", | |
"colab_type": "code", | |
"outputId": "86097a72-6c8d-4e18-e8b2-0e115bca3945", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"pred = predict(input_ids)\n", | |
"torch.softmax(pred, dim = 1)\n" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[0.0011, 0.9989]], device='cuda:0', grad_fn=<SoftmaxBackward>)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 27 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wpNkwy6_bUEd", | |
"colab_type": "code", | |
"outputId": "491b4b9d-ea96-492c-b0e6-87070bad077c", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"# Check output of custom_forward\n", | |
"custom_forward(input_ids)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([0.9989], device='cuda:0', grad_fn=<UnsqueezeBackward0>)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 28 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "YAzBqQlpbUEk", | |
"colab_type": "code", | |
"outputId": "d96b1d65-38ac-4fd4-970a-86570893246c", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
} | |
}, | |
"source": [ | |
"attributions, delta = lig.attribute(inputs=input_ids,\n", | |
" baselines=ref_input_ids,\n", | |
" n_steps=700,\n", | |
" internal_batch_size=3,\n", | |
" return_convergence_delta=True)\n", | |
"delta" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([-0.5689], device='cuda:0')" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 29 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dU8SRQFybUEo", | |
"colab_type": "code", | |
"outputId": "526898f4-3b6d-4946-c9aa-71e402c84753", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 51 | |
} | |
}, | |
"source": [ | |
"score = predict(input_ids)\n", | |
"\n", | |
"print('Sentence: ', text)\n", | |
"print('Sentiment: ' + str(torch.argmax(score[0]).cpu().numpy()) + \\\n", | |
" ', Probability positive: ' + str(torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()))" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Sentence: The movie was one of those amazing movies\n", | |
"Sentiment: 1, Probability positive: 0.9988709\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Hq8R_ZYubUEu", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def summarize_attributions(attributions):\n", | |
" attributions = attributions.sum(dim=-1).squeeze(0)\n", | |
" attributions = attributions / torch.norm(attributions)\n", | |
" return attributions" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3q7xXwRrbUEx", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"attributions_sum = summarize_attributions(attributions)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0ZF0RmZ4bUE1", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# storing couple samples in an array for visualization purposes\n", | |
"score_vis = viz.VisualizationDataRecord(attributions_sum,\n", | |
" torch.softmax(score, dim = 1)[0][1],\n", | |
" torch.argmax(torch.softmax(score, dim = 0)[0]),\n", | |
" 1,\n", | |
" text,\n", | |
" attributions_sum.sum(), \n", | |
" all_tokens,\n", | |
" delta)\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-gAojuO6ody0", | |
"colab_type": "code", | |
"outputId": "03a72a37-2e04-495e-a1e0-5a191e34733e", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 87 | |
} | |
}, | |
"source": [ | |
"print('\\033[1m', 'Visualization For Score', '\\033[0m')\n", | |
"viz.visualize_text([score_vis])" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\u001b[1m Visualization For Score \u001b[0m\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"<table width: 100%><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>1</b></text></td><td><text style=\"padding-right:2em\"><b>1 (1.00)</b></text></td><td><text style=\"padding-right:2em\"><b>The movie was one of those amazing movies</b></text></td><td><text style=\"padding-right:2em\"><b>0.26</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [CLS] </font></mark><mark style=\"background-color: hsl(0, 75%, 91%); opacity:1.0; line-height:1.75\"><font color=\"black\"> The </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movie </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0; line-height:1.75\"><font color=\"black\"> was </font></mark><mark style=\"background-color: hsl(120, 75%, 90%); opacity:1.0; line-height:1.75\"><font color=\"black\"> one </font></mark><mark style=\"background-color: hsl(120, 75%, 89%); opacity:1.0; line-height:1.75\"><font color=\"black\"> of </font></mark><mark style=\"background-color: hsl(0, 75%, 87%); opacity:1.0; line-height:1.75\"><font color=\"black\"> those </font></mark><mark style=\"background-color: hsl(120, 75%, 60%); opacity:1.0; line-height:1.75\"><font color=\"black\"> amazing </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0; line-height:1.75\"><font color=\"black\"> movies </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0; line-height:1.75\"><font color=\"black\"> [SEP] </font></mark></td><tr></table>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ItXD4N9FogZu", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"The visualization is clearly meaningless! :(\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "N-njUmIyzzGI", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"torch.softmax(score, dim = 1)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "leiLQGsVmISM", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@NarineK solution in pytorch/captum#311 (comment)