Created
July 26, 2018 14:30
-
-
Save tomonari-masada/81239ab6ee576033e680de99f58f848f to your computer and use it in GitHub Desktop.
packed_sequences.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "packed_sequences.ipynb", | |
"version": "0.3.2", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"[View in Colaboratory](https://colab.research.google.com/gist/tomonari-masada/81239ab6ee576033e680de99f58f848f/packed_sequences.ipynb)" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "9IQhmDmm0Tnl", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"import torch" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "Wo0vrvfz3GfS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# PyTorch install snippet for Colab, see https://pytorch.org/\n", | |
"from os import path\n", | |
"from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n", | |
"\n", | |
"# Wheel tag assembled from the three helper values; it is substituted into the wheel file name below.\n", | |
"platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n", | |
"\n", | |
"# Use the 'cu80' wheels when nvidia-smi is present on the machine, the 'cpu' wheels otherwise.\n", | |
"accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'\n", | |
"\n", | |
"# Fetch the wheel over HTTPS rather than plain HTTP so it cannot be altered in transit.\n", | |
"!pip install -q https://download.pytorch.org/whl/{accelerator}/torch-0.4.0-{platform}-linux_x86_64.whl torchvision\n", | |
"import torch" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "AwkqSZOs0cR6", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"import torch.nn as nn" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "Q-FcGJ_z9Mvp", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# Seed PyTorch's RNG so the randomly initialised layers created later give the same numbers on every Run All.\n", | |
"torch.manual_seed(0)\n", | |
"\n", | |
"output_size = 10  # passed to nn.Embedding as the table size and to nn.Linear as the output width\n", | |
"h_dim = 3  # embedding width, also the LSTM input and hidden size" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "7n-b1i930mXJ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# Toy corpus of token-index sequences; the smallest index is 1.\n", | |
"sentences = [\n", | |
"    [3, 1, 5, 2],\n", | |
"    [2, 5],\n", | |
"    [4, 2, 2],\n", | |
"]\n", | |
"# Next-token targets: drop the first token and append 10, which means EOS.\n", | |
"targets = [[*s[1:], 10] for s in sentences]" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "Nf0xR1fB9zGI", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"def decr(ss):\n", | |
"    \"\"\"Return a new nested list with 1 subtracted from every element of every sequence in ss.\"\"\"\n", | |
"    return [[x - 1 for x in s] for s in ss]" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "FqI6EMPU5kvI", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "8a45a535-abb6-495e-a00a-7a304ed2a362" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"sentences = decr(sentences)\n", | |
"sentences" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[[2, 0, 4, 1], [1, 4], [3, 1, 1]]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "WQ8gYOYs-KfQ", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "5c6a7a97-c7e8-4e5b-b105-04b3c3b60cfd" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"targets = decr(targets)\n", | |
"targets" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[[0, 4, 1, 9], [4, 9], [1, 1, 9]]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "PCfshB4c0oph", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"emb = nn.Embedding(output_size, h_dim)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "bgQiTvu102Uq", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 215 | |
}, | |
"outputId": "79efb96e-ab28-450a-cd3e-89e7213eb6ee" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"emb.weight" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"Parameter containing:\n", | |
"tensor([[-0.1173, -0.4201, -0.6489],\n", | |
" [-2.4970, 0.9485, -0.1642],\n", | |
" [-1.7108, -1.9750, -0.4109],\n", | |
" [-0.6908, -0.3993, 0.6631],\n", | |
" [-1.9111, -0.1232, -1.2113],\n", | |
" [ 0.7985, -0.0895, 0.1405],\n", | |
" [ 1.1284, 0.4328, -1.3285],\n", | |
" [-1.4747, 0.3554, -1.4982],\n", | |
" [-1.7153, 0.0813, 0.8034],\n", | |
" [ 0.9740, 0.8225, 2.0396]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "yC45uK5w04qp", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "94ec20fc-d354-4e49-d292-5d769810bbf0" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# Positions of the sentences, ordered from the longest sentence to the shortest.\n", | |
"perm = sorted(\n", | |
"    range(len(sentences)),\n", | |
"    key=lambda idx: len(sentences[idx]),\n", | |
"    reverse=True,\n", | |
")\n", | |
"perm" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[0, 2, 1]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 11 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "kikybbFn08Ja", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "2c17c0d0-51ac-4161-ad92-ce9cd45690eb" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"sentences = [sentences[i] for i in perm]\n", | |
"sentences" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[[2, 0, 4, 1], [3, 1, 1], [1, 4]]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 12 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "if6xfmDm-xsw", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "46374bb6-b21a-40a0-cdde-46543c6217bd" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"targets = [targets[i] for i in perm]\n", | |
"targets" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[[0, 4, 1, 9], [1, 1, 9], [4, 9]]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 13 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "neTu66DV42V6", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "f9b492ad-fecd-4dc6-dba7-2f4e1d3ada3d" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"nn.utils.rnn.pack_sequence([torch.tensor(s) for s in sentences])" | |
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"PackedSequence(data=tensor([ 2, 3, 1, 0, 1, 4, 4, 1, 1]), batch_sizes=tensor([ 3, 3, 2, 1]))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 14 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "ifrJDWuO080B", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 179 | |
}, | |
"outputId": "ec1a27af-a34a-488e-de62-856ac0bf3541" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"packed = nn.utils.rnn.pack_sequence([emb(torch.tensor(s)) for s in sentences])\n", | |
"packed" | |
], | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"PackedSequence(data=tensor([[-1.7108, -1.9750, -0.4109],\n", | |
" [-0.6908, -0.3993, 0.6631],\n", | |
" [-2.4970, 0.9485, -0.1642],\n", | |
" [-0.1173, -0.4201, -0.6489],\n", | |
" [-2.4970, 0.9485, -0.1642],\n", | |
" [-1.9111, -0.1232, -1.2113],\n", | |
" [-1.9111, -0.1232, -1.2113],\n", | |
" [-2.4970, 0.9485, -0.1642],\n", | |
" [-2.4970, 0.9485, -0.1642]]), batch_sizes=tensor([ 3, 3, 2, 1]))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 15 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "OnppUP4D3fvR", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"lstm_layer = nn.LSTM(h_dim, h_dim)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "XxT60x-s3lhZ", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 287 | |
}, | |
"outputId": "6796acf3-470e-486d-9736-70e03d86ce59" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"output = lstm_layer(packed)\n", | |
"output" | |
], | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(PackedSequence(data=tensor([[-0.0594, 0.0435, -0.1823],\n", | |
" [-0.1311, 0.0501, -0.0162],\n", | |
" [-0.0069, 0.0480, 0.1320],\n", | |
" [ 0.0176, 0.0143, -0.2759],\n", | |
" [-0.1226, 0.0507, 0.1302],\n", | |
" [ 0.1162, 0.0678, -0.2150],\n", | |
" [ 0.1876, 0.0323, -0.2292],\n", | |
" [-0.1069, 0.0639, 0.1501],\n", | |
" [ 0.1483, 0.0571, 0.1160]]), batch_sizes=tensor([ 3, 3, 2, 1])),\n", | |
" (tensor([[[ 0.1483, 0.0571, 0.1160],\n", | |
" [-0.1069, 0.0639, 0.1501],\n", | |
" [ 0.1162, 0.0678, -0.2150]]]),\n", | |
" tensor([[[ 0.2251, 0.4000, 0.1716],\n", | |
" [-0.1619, 0.4457, 0.2130],\n", | |
" [ 0.1634, 0.3167, -0.3403]]])))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "mLE-QZc95T_C", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 179 | |
}, | |
"outputId": "b0fed3ff-363b-4fda-acc4-5d1add0415ed" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"output[0].data" | |
], | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[-0.0594, 0.0435, -0.1823],\n", | |
" [-0.1311, 0.0501, -0.0162],\n", | |
" [-0.0069, 0.0480, 0.1320],\n", | |
" [ 0.0176, 0.0143, -0.2759],\n", | |
" [-0.1226, 0.0507, 0.1302],\n", | |
" [ 0.1162, 0.0678, -0.2150],\n", | |
" [ 0.1876, 0.0323, -0.2292],\n", | |
" [-0.1069, 0.0639, 0.1501],\n", | |
" [ 0.1483, 0.0571, 0.1160]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "hBEu63KB8ou4", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"dense = nn.Linear(h_dim, output_size)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "WhmX5pS98u75", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 341 | |
}, | |
"outputId": "c2014849-e599-4f66-c2ca-f7413574b0f9" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"output = dense(output[0].data)\n", | |
"output" | |
], | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor([[-0.5985, -0.0634, 0.1536, -0.5625, -0.0715, -0.1758, -0.2652,\n", | |
" -0.4059, -0.1514, -0.2589],\n", | |
" [-0.5833, -0.1209, 0.2167, -0.5463, -0.1616, -0.0596, -0.2566,\n", | |
" -0.4300, -0.1707, -0.2158],\n", | |
" [-0.4590, -0.0849, 0.2036, -0.4992, -0.1926, -0.0379, -0.3052,\n", | |
" -0.4350, -0.1383, -0.2750],\n", | |
" [-0.5742, -0.0165, 0.0922, -0.5624, -0.0039, -0.2513, -0.2814,\n", | |
" -0.3876, -0.1343, -0.3041],\n", | |
" [-0.5266, -0.1385, 0.2446, -0.5192, -0.2214, 0.0121, -0.2706,\n", | |
" -0.4447, -0.1691, -0.2152],\n", | |
" [-0.5220, 0.0260, 0.0986, -0.5418, -0.0216, -0.2770, -0.3144,\n", | |
" -0.3892, -0.1004, -0.3485],\n", | |
" [-0.4696, 0.0580, 0.0525, -0.5280, 0.0131, -0.3074, -0.3352,\n", | |
" -0.3793, -0.0862, -0.3886],\n", | |
" [-0.5169, -0.1328, 0.2502, -0.5148, -0.2300, 0.0122, -0.2765,\n", | |
" -0.4464, -0.1631, -0.2216],\n", | |
" [-0.3809, -0.0090, 0.1513, -0.4771, -0.1501, -0.1176, -0.3499,\n", | |
" -0.4208, -0.0950, -0.3547]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 20 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "nEvN4_yM8HWx", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "82a30831-fc52-4d6c-d92e-70511449a381" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"targets = nn.utils.rnn.pack_sequence([torch.tensor(s) for s in targets])\n", | |
"targets" | |
], | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"PackedSequence(data=tensor([ 0, 1, 4, 4, 1, 9, 1, 9, 9]), batch_sizes=tensor([ 3, 3, 2, 1]))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 21 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "WWsSWtSQ7WhH", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"loss = nn.CrossEntropyLoss()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "4N9P2pTq7hu3", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"outputId": "76260bf6-1a5b-4af5-c34b-c764c1e86e86" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"loss(output, targets.data)" | |
], | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensor(2.3085)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 23 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Hrzrb_nG7ia_", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment