-1 trick for masking subsets of sequence for loss.ipynb
Created March 29, 2023 05:18
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyPj2ODa51/hvugZO/9voyvR",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/myazdani/f71adc24fb57fd6a40855b832e555b40/-1-trick-for-masking-subsets-of-sequence-for-loss.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# -1 trick for masking subsets of a sequence so they are ignored in calculating loss\n",
        "\n",
        "We want to predict a sequence that looks like this:\n",
        "```\n",
        "[-1, -1, -1, 8, 2, 1]\n",
        "```\n",
        "\n",
        "But we only care about predicting the `[8, 2, 1]` bit. The `-1` entries just signal that those parts of the sequence should be ignored. This comes up because our sequences vary in length: sometimes length 3 (as in this example), sometimes length 6, sometimes length 1, and so on. Since we want every tensor to have a fixed size of 6, we pad with a dummy symbol like `-1` to mask out the parts we don't care about.\n",
        "\n",
        "Let's verify that `-1` is a good choice for ignoring subsets of the sequence when using the cross entropy loss.\n",
        "\n",
        "First, a starter example. We want to predict this sequence (which contains no masking):\n",
        "```\n",
        "[9, 1, 3, 7, 8, 2]\n",
        "```"
      ],
      "metadata": {
        "id": "UTziS0zttelL"
      }
    },
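    {
      "cell_type": "markdown",
      "source": [
        "As a quick sketch of the padding step described above (the helper name `pad_target` is just illustrative, not part of the original setup), a shorter target can be right-aligned into a fixed-length tensor filled with `-1`:"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "def pad_target(seq, length=6, pad_value=-1):\n",
        "    # Fill a fixed-length tensor with the dummy symbol, then\n",
        "    # right-align the real targets, as in the example above.\n",
        "    padded = torch.full((length,), pad_value, dtype=torch.long)\n",
        "    padded[-len(seq):] = torch.tensor(seq)\n",
        "    return padded\n",
        "\n",
        "pad_target([8, 2, 1])  # tensor([-1, -1, -1,  8,  2,  1])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },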
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "egBUg18fstC0"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "target_seq = torch.tensor([9, 1, 3, 7, 8, 2])"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We'll use a helper function that turns a sequence into dummy logit scores."
      ],
      "metadata": {
        "id": "SLuTOtHKukYH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def seq_to_logit(input_seq, max_score=1, max_class=10):\n",
        "    # One row per sequence element; put max_score at the target index.\n",
        "    scores = torch.zeros(len(input_seq), max_class)\n",
        "    for i in range(len(input_seq)):\n",
        "        scores[i, input_seq[i]] = max_score\n",
        "    return scores"
      ],
      "metadata": {
        "id": "nlB2c8-rujLm"
      },
      "execution_count": 2,
      "outputs": []
    },
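    {
      "cell_type": "markdown",
      "source": [
        "Equivalently, the loop can be vectorized with `torch.nn.functional.one_hot` (a sketch; the name `seq_to_logit_vec` is just illustrative):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import torch.nn.functional as F\n",
        "\n",
        "def seq_to_logit_vec(input_seq, max_score=1, max_class=10):\n",
        "    # one_hot places a 1 at each target index; scale it to max_score.\n",
        "    return F.one_hot(input_seq, num_classes=max_class).float() * max_score\n",
        "\n",
        "torch.equal(seq_to_logit_vec(target_seq), seq_to_logit(target_seq))  # True"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },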
    {
      "cell_type": "code",
      "source": [
        "seq_to_logit(target_seq)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qALpDigzuat9",
        "outputId": "50d27fcc-5001-413c-df21-eb0519481aa0"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
              "        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],\n",
              "        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "For the sequence `[9, 1, 3, 7, 8, 2]`, the logit scores are 0 everywhere except at the positions where the sequence values appear. So for the first element of the sequence, `9`, we see a value of `1.0` at the 10th position (index 9) of the first row of the logits.\n",
        "\n",
        "The score doesn't have to be `1.0`. To simulate putting essentially all of the probability mass on the correct class, we'll use something much bigger:"
      ],
      "metadata": {
        "id": "4CQS_4pbvSgz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "logits = seq_to_logit(target_seq, 1000)\n",
        "logits"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eJCtJuOnvOzP",
        "outputId": "9121ad4f-9a28-45bd-8758-b3785f31ac48"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., 1000.],\n",
              "        [   0., 1000.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],\n",
              "        [   0.,    0.,    0., 1000.,    0.,    0.,    0.,    0.,    0.,    0.],\n",
              "        [   0.,    0.,    0.,    0.,    0.,    0.,    0., 1000.,    0.,    0.],\n",
              "        [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., 1000.,    0.],\n",
              "        [   0.,    0., 1000.,    0.,    0.,    0.,    0.,    0.,    0.,    0.]])"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's see what the loss of this set of logit scores is against the target sequence:"
      ],
      "metadata": {
        "id": "H-1t0Kduw0DE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "torch.nn.functional.cross_entropy(logits, target_seq, ignore_index=-1)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "o8kdWrVDvwGr",
        "outputId": "dce9878e-baf0-4193-b93e-faf4cf0c8c63"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor(0.)"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Great, the loss is 0.0, as expected: with a logit of 1000 at the correct class and 0 everywhere else, the softmax puts essentially all of the probability mass in the right place, so the cross entropy is numerically 0.\n",
        "\n",
        "Now let's try the sequence with masking:\n",
        "```\n",
        "[-1, -1, -1, 8, 2, 1]\n",
        "```"
      ],
      "metadata": {
        "id": "FwSU2VJPwzFn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "masked_seq = torch.tensor([-1, -1, -1, 8, 2, 1])\n",
        "pred_seq = torch.tensor([3, 4, 3, 8, 2, 1])\n",
        "pred_logits = seq_to_logit(pred_seq, 1000)\n",
        "pred_logits"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Nhn4evgIv6Eh",
        "outputId": "40985462-11d3-4f4f-df2d-754400d5a749"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[   0.,    0.,    0., 1000.,    0.,    0.,    0.,    0.,    0.,    0.],\n",
              "        [   0.,    0.,    0.,    0., 1000.,    0.,    0.,    0.,    0.,    0.],\n",
              "        [   0.,    0.,    0., 1000.,    0.,    0.,    0.,    0.,    0.,    0.],\n",
              "        [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0., 1000.,    0.],\n",
              "        [   0.,    0., 1000.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],\n",
              "        [   0., 1000.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.]])"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Here the predicted sequence is `[3, 4, 3, 8, 2, 1]`, but the first three elements are ignored because the target sequence masks out those positions. In other words, it doesn't matter what we predict in the first three positions; only the last three count toward the loss."
      ],
      "metadata": {
        "id": "hXmHdUBLxBCV"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "torch.nn.functional.cross_entropy(pred_logits, masked_seq, ignore_index=-1)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7e172O_6wRdL",
        "outputId": "9ea36590-33fd-4c18-ea2c-a04343d67919"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor(0.)"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
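    {
      "cell_type": "markdown",
      "source": [
        "A quick aside (a sketch, not part of the original flow): the `-1` entries are only safe *because* of `ignore_index`. Without it, `-1` would be treated as a class index, which is out of bounds for 10 classes, and `cross_entropy` raises an error:"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Without ignore_index, the dummy -1 targets are treated as class\n",
        "# indices, which are out of bounds for 10 classes, so this errors.\n",
        "try:\n",
        "    torch.nn.functional.cross_entropy(pred_logits, masked_seq)\n",
        "except Exception as e:\n",
        "    print(type(e).__name__, e)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },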
    {
      "cell_type": "markdown",
      "source": [
        "Sure enough, we still get 0 loss. Let's double-check by changing the predicted sequence so that one of the unmasked positions is wrong."
      ],
      "metadata": {
        "id": "VIzaFgRGxTqj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "masked_seq = torch.tensor([-1, -1, -1, 8, 2, 1])\n",
        "pred_seq = torch.tensor([3, 4, 3, 8, 2, 8])\n",
        "pred_logits = seq_to_logit(pred_seq, 1000)\n",
        "torch.nn.functional.cross_entropy(pred_logits, masked_seq, ignore_index=-1)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4s-MXfOiwY23",
        "outputId": "24818728-3775-49b3-9de8-6d0549dcc444"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor(333.3333)"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now we do see a loss, as expected. The value is exactly `1000 / 3`: the one wrong position contributes a loss of (numerically) 1000, the two correct positions contribute 0, and `ignore_index` averages over only the 3 unmasked positions rather than all 6. We can verify this directly, as in the check below."
      ],
      "metadata": {
        "id": "HRz2nz3exev8"
      }
    },
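    {
      "cell_type": "markdown",
      "source": [
        "As a final sanity check (a sketch using only what's defined above): the masked loss should match the plain cross entropy computed over just the last three positions, confirming that `ignore_index=-1` averages over the unmasked positions only."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# The wrong position contributes ~1000 (logsumexp of the row minus the\n",
        "# logit at the target class, i.e. ~1000 - 0); the two correct positions\n",
        "# contribute ~0, so the mean over the 3 unmasked positions is ~333.33.\n",
        "torch.nn.functional.cross_entropy(pred_logits[3:], masked_seq[3:])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }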
  ]
} |