-1 trick for masking subsets of sequence for loss.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPj2ODa51/hvugZO/9voyvR",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/myazdani/f71adc24fb57fd6a40855b832e555b40/-1-trick-for-masking-subsets-of-sequence-for-loss.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# -1 trick for masking subsets of a sequence to be ignored in calculating loss\n",
"\n",
"We want to predict a sequence that looks like this:\n",
"```\n",
"[-1, -1, -1, 8, 2, 1]\n",
"```\n",
"\n",
"But we only care about predicting the `[8, 1, 1]` bit. The `-1` bits just signal to us to ignore those parts of the sequence. This is because sometimes we might have sequences that are of length 3 (like this example), sometimes length 6, length 1, and so on. We want to fit all tensors to size 6 though so we revert to using a dummy symbol like `-1` to mask out the parts we don't care about. \n",
"\n",
"Lets verify that `-1` is a good choice for ignoring the subsets of the sequence when using the cross entropy loss. \n",
"\n",
"\n",
"\n",
"First, a starter example. We want to predict this sequence (that contains no masking):\n",
"```\n",
"[9, 1, 3, 7, 8, 2]\n",
"```"
],
"metadata": {
"id": "UTziS0zttelL"
}
},
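{
"cell_type": "markdown",
"source": [
"As an aside, here is a minimal sketch of how a padded target like this might be built. The helper name `pad_target` and its arguments are purely illustrative; nothing later in the notebook depends on it."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"def pad_target(seq, max_len=6, pad_value=-1):\n",
"    # Hypothetical helper: left-pad a short target sequence with -1 so that\n",
"    # every target tensor ends up with the same fixed length.\n",
"    padded = torch.full((max_len,), pad_value, dtype=torch.long)\n",
"    padded[max_len - len(seq):] = torch.tensor(seq)\n",
"    return padded\n",
"\n",
"pad_target([8, 2, 1])  # expected: tensor([-1, -1, -1,  8,  2,  1])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},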
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "egBUg18fstC0"
},
"outputs": [],
"source": [
"import torch\n",
"target_seq = torch.tensor([9, 1, 3, 7, 8, 2])"
]
},
{
"cell_type": "markdown",
"source": [
"We'll have a helper function that turns a sequence into dummy logit scores."
],
"metadata": {
"id": "SLuTOtHKukYH"
}
},
{
"cell_type": "code",
"source": [
"def seq_to_logit(input_seq, max_score = 1, max_class=10):\n",
" scores = torch.zeros(len(input_seq), max_class)\n",
" for i in range(len(input_seq)):\n",
" scores[i,input_seq[i]] = max_score\n",
" return scores"
],
"metadata": {
"id": "nlB2c8-rujLm"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"seq_to_logit(target_seq)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qALpDigzuat9",
"outputId": "50d27fcc-5001-413c-df21-eb0519481aa0"
},
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],\n",
" [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "markdown",
"source": [
"For the sequence `[9, 1, 3, 7, 8, 2]` you can see the logit scores have a value of 0 everywhere except where the sequence appears. So for the first element of the sequence, `9`, we see a value of `1.0` at the 10th position in the first row of the logits. \n",
"\n",
"We don't need to put `1.0`. In fact, to \"simulate\" the idea of putting all the probability mass in the correct place, we'll use something much bigger:\n",
"\n"
],
"metadata": {
"id": "4CQS_4pbvSgz"
}
},
{
"cell_type": "code",
"source": [
"logits = seq_to_logit(target_seq, 1000)\n",
"logits"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eJCtJuOnvOzP",
"outputId": "9121ad4f-9a28-45bd-8758-b3785f31ac48"
},
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 1000.],\n",
" [ 0., 1000., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 0., 0., 1000., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 0., 0., 0., 0., 0., 0., 1000., 0., 0.],\n",
" [ 0., 0., 0., 0., 0., 0., 0., 0., 1000., 0.],\n",
" [ 0., 0., 1000., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"metadata": {},
"execution_count": 4
}
]
},
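{
"cell_type": "markdown",
"source": [
"As a quick check (not part of the original walkthrough), the softmax of these logits shows why `1000` \"simulates\" putting all the probability mass on the correct class: each row comes out as essentially a one-hot distribution."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Each row should be ~1.0 at the target class and ~0.0 everywhere else,\n",
"# i.e. a \"perfectly confident\" prediction.\n",
"torch.softmax(logits, dim=-1)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},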
{
"cell_type": "markdown",
"source": [
"Let's see what the loss of this set of logit scores are against the target sequence:"
],
"metadata": {
"id": "H-1t0Kduw0DE"
}
},
{
"cell_type": "code",
"source": [
"torch.nn.functional.cross_entropy(logits, target_seq, ignore_index=-1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "o8kdWrVDvwGr",
"outputId": "dce9878e-baf0-4193-b93e-faf4cf0c8c63"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(0.)"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "markdown",
"source": [
"Great, as we expect the loss is 0.0. This makes sense!\n",
"\n",
"Now lets try the sequence with masking:\n",
"```\n",
"[-1, -1, -1, 8, 2, 1]\n",
"```"
],
"metadata": {
"id": "FwSU2VJPwzFn"
}
},
{
"cell_type": "code",
"source": [
"masked_seq = torch.tensor([-1, -1, -1, 8, 2, 1])\n",
"pred_seq = torch.tensor([3, 4, 3, 8, 2, 1])\n",
"pred_logits = seq_to_logit(pred_seq, 1000)\n",
"pred_logits"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Nhn4evgIv6Eh",
"outputId": "40985462-11d3-4f4f-df2d-754400d5a749"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0., 0., 0., 1000., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 0., 0., 0., 1000., 0., 0., 0., 0., 0.],\n",
" [ 0., 0., 0., 1000., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 0., 0., 0., 0., 0., 0., 0., 1000., 0.],\n",
" [ 0., 0., 1000., 0., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 1000., 0., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "markdown",
"source": [
"In this case the prediction sequence is `[3, 4, 3, 8, 2, 1]` but note the first 3 elements are ignored because the masked sequence is masking out those values. In other words, we don't care what we predict in the first 3 elements since the masking only cares about the last three. "
],
"metadata": {
"id": "hXmHdUBLxBCV"
}
},
{
"cell_type": "code",
"source": [
"torch.nn.functional.cross_entropy(pred_logits, masked_seq, ignore_index=-1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7e172O_6wRdL",
"outputId": "9ea36590-33fd-4c18-ea2c-a04343d67919"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(0.)"
]
},
"metadata": {},
"execution_count": 7
}
]
},
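{
"cell_type": "markdown",
"source": [
"As a sanity check (a sketch, not something from the original notebook), `ignore_index=-1` should give the same answer as manually dropping the masked positions and computing the loss over the remaining ones."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Keep only positions whose target is not -1, then average the loss over them.\n",
"# This should match cross_entropy(..., ignore_index=-1) above.\n",
"keep = masked_seq != -1\n",
"torch.nn.functional.cross_entropy(pred_logits[keep], masked_seq[keep])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},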
{
"cell_type": "markdown",
"source": [
"Sure enough, we still get 0 loss. Lets verify by changing the pred sequence to predict something wrong. "
],
"metadata": {
"id": "VIzaFgRGxTqj"
}
},
{
"cell_type": "code",
"source": [
"masked_seq = torch.tensor([-1, -1, -1, 8, 2, 1])\n",
"pred_seq = torch.tensor([3, 4, 3, 8, 2, 8])\n",
"pred_logits = seq_to_logit(pred_seq, 1000)\n",
"torch.nn.functional.cross_entropy(pred_logits, masked_seq, ignore_index=-1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4s-MXfOiwY23",
"outputId": "24818728-3775-49b3-9de8-6d0549dcc444"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(333.3333)"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "markdown",
"source": [
"Now we do see a loss as we expect. "
],
"metadata": {
"id": "HRz2nz3exev8"
}
},
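{
"cell_type": "markdown",
"source": [
"The value makes sense: only one of the three unmasked positions is wrong, its cross entropy is roughly the logit gap of `1000`, and `cross_entropy` averages over the unmasked positions, giving about `1000 / 3`. A rough check of that arithmetic (again just a sketch):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Per-position losses over the unmasked targets: roughly [0, 0, 1000],\n",
"# whose mean is ~333.33.\n",
"keep = masked_seq != -1\n",
"per_position = torch.nn.functional.cross_entropy(\n",
"    pred_logits[keep], masked_seq[keep], reduction='none'\n",
")\n",
"per_position, per_position.mean()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},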
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "usIgjSArxcsg"
},
"execution_count": 8,
"outputs": []
}
]
}