A simple usage of pack_padded_sequence and pad_packed_sequence in PyTorch
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.autograd import Variable\n",
"from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
"from itertools import chain"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequences: [['hello'], ['hi', 'how', 'are', 'you'], ['good', 'thank', 'you']]\n",
"max_seq_len: 4\n",
"Vectorized: [[3], [4, 5, 1, 7], [2, 6, 7]]\n"
]
}
],
"source": [
"seqs = ['hello', 'hi how are you','good thank you']\n",
"seqs = [s.split(' ') for s in seqs]\n",
"print ('Sequences:', seqs)\n",
"max_seq_len = max(len(s) for s in seqs)\n",
"print('max_seq_len:', max_seq_len)\n",
"vocab = ['<pad>'] + sorted(list(set(chain.from_iterable(seqs))))\n",
"\n",
"vocab_size = len(vocab)\n",
"embd_size = 12 # E\n",
"hidden_size = 5 # H\n",
"\n",
"embed = nn.Embedding(len(vocab), 12)\n",
"lstm = nn.LSTM(12, 5)\n",
"\n",
"vec_seqs = [[vocab.index(w) for w in s] for s in seqs]\n",
"print('Vectorized:', vec_seqs)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sorted seq_tensor Variable containing:\n",
" 4 5 1 7\n",
" 2 6 7 0\n",
" 3 0 0 0\n",
"[torch.LongTensor of size 3x4]\n",
"\n",
"--- torch.Size([4, 3])\n",
"packed_input: torch.Size([8, 12])\n",
"packed_output: torch.Size([8, 5])\n",
"output: torch.Size([4, 3, 5])\n",
"Variable containing:\n",
"(0 ,.,.) = \n",
" -0.1543 -0.1183 -0.0109 0.0120 -0.1675\n",
" -0.1021 -0.0724 0.3059 -0.1867 -0.0173\n",
" 0.3118 0.0151 0.2762 0.2161 0.0431\n",
"\n",
"(1 ,.,.) = \n",
" -0.0578 -0.0463 0.1577 0.1174 -0.1212\n",
" -0.0393 0.1662 -0.0028 -0.0290 -0.0161\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000\n",
"\n",
"(2 ,.,.) = \n",
" 0.0893 -0.0111 0.1398 -0.2752 -0.0908\n",
" -0.0861 0.1073 0.0872 0.0157 -0.0729\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000\n",
"\n",
"(3 ,.,.) = \n",
" -0.0328 -0.0071 0.1282 -0.0450 -0.1021\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000\n",
"[torch.FloatTensor of size 4x3x5]\n",
"\n",
"torch.Size([3, 5])\n"
]
}
],
"source": [
"seq_lengths = torch.LongTensor([len(seq) for seq in vec_seqs])\n",
"\n",
"seq_tensor = []\n",
"for seq in vec_seqs:\n",
" pad_len = max(0, max_seq_len - len(seq))\n",
" seq_tensor.append(seq + [0] * pad_len)\n",
"seq_tensor = Variable(torch.LongTensor(seq_tensor))\n",
"\n",
"# sort the tensors for pack_padded_sequence\n",
"seq_lengths, seq_indices = seq_lengths.sort(0, descending=True)\n",
"seq_tensor = seq_tensor[seq_indices]\n",
"print('sorted seq_tensor', seq_tensor)\n",
"\n",
"# transpose the tensor for RNN (batch_first is false in default)\n",
"seq_tensor = seq_tensor.transpose(0,1) # (B,L) -> (L,B). B: batch_size, L: Length, \n",
"\n",
"seq_tensor = embed(seq_tensor) # (L, B, E)\n",
"packed_input = pack_padded_sequence(seq_tensor, seq_lengths.numpy()) # (sum_of_valid_token=8, E)\n",
"print('packed_input:', packed_input.data.size())\n",
"\n",
"packed_output, (ht, ct) = lstm(packed_input)\n",
"print('packed_output:', packed_output.data.size()) # (sum_of_valid_token=8, H)\n",
"\n",
"# unpack the output with paddings\n",
"output, _ = pad_packed_sequence(packed_output) # (L, bs, H)\n",
"print ('output:', output.size())\n",
"print(output) # you can see 0 values of padding\n",
"\n",
"# final hidden state\n",
"print (ht[-1].size()) # (B, H)"
]
}
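,
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Not part of the original gist: a minimal sketch of how to read off the last valid output of\n",
"# every sequence from the padded `output` tensor, assuming the variables from the previous cell\n",
"# (output, seq_lengths, ht) are still in scope. For this single-layer, unidirectional LSTM the\n",
"# gathered rows should contain the same values as ht[-1].\n",
"last_idx = (seq_lengths - 1).view(1, -1, 1).expand(1, output.size(1), output.size(2)).contiguous() # (1, B, H)\n",
"last_outputs = output.gather(0, Variable(last_idx)).squeeze(0) # (B, H)\n",
"print(last_outputs)\n",
"print(ht[-1]) # should match last_outputs row for row"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Not part of the original gist: the same pack/pad round trip sketched with batch_first=True,\n",
"# which keeps tensors in (B, L, ...) layout so no transpose is needed. Assumes vec_seqs,\n",
"# max_seq_len, embed, embd_size and hidden_size from the earlier cells; the *_bf names are new.\n",
"lstm_bf = nn.LSTM(embd_size, hidden_size, batch_first=True)\n",
"\n",
"padded = [s + [0] * (max_seq_len - len(s)) for s in vec_seqs]\n",
"lengths_bf = torch.LongTensor([len(s) for s in vec_seqs])\n",
"lengths_bf, order = lengths_bf.sort(0, descending=True)\n",
"batch = Variable(torch.LongTensor(padded))[order] # (B, L)\n",
"\n",
"packed_bf = pack_padded_sequence(embed(batch), lengths_bf.numpy(), batch_first=True)\n",
"output_bf, (ht_bf, ct_bf) = lstm_bf(packed_bf)\n",
"output_bf, _ = pad_packed_sequence(output_bf, batch_first=True) # (B, L, H)\n",
"print('batch_first output:', output_bf.size())"
]
}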
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}