-
-
Save HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec to your computer and use it in GitHub Desktop.
Minimal tutorial on packing (pack_padded_sequence) and unpacking (pad_packed_sequence) sequences in pytorch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import LongTensor | |
from torch.nn import Embedding, LSTM | |
from torch.autograd import Variable | |
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | |
## We want to run LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium'] | |
# | |
# Step 1: Construct Vocabulary | |
# Step 2: Load indexed data (list of instances, where each instance is list of character indices) | |
# Step 3: Make Model | |
# * Step 4: Pad instances with 0s till max length sequence | |
# * Step 5: Sort instances by sequence length in descending order | |
# * Step 6: Embed the instances | |
# * Step 7: Call pack_padded_sequence with embeded instances and sequence lengths | |
# * Step 8: Forward with LSTM | |
# * Step 9: Call unpack_padded_sequences if required / or just pick last hidden vector | |
# * Summary of Shape Transformations | |
# We want to run LSTM on a batch following 3 character sequences | |
seqs = ['long_str', # len = 8 | |
'tiny', # len = 4 | |
'medium'] # len = 6 | |
## Step 1: Construct Vocabulary ## | |
##------------------------------## | |
# make sure <pad> idx is 0 | |
vocab = ['<pad>'] + sorted(set([char for seq in seqs for char in seq])) | |
# => ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y'] | |
## Step 2: Load indexed data (list of instances, where each instance is list of character indices) ## | |
##-------------------------------------------------------------------------------------------------## | |
vectorized_seqs = [[vocab.index(tok) for tok in seq]for seq in seqs] | |
# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10], | |
# [12, 5, 8, 14], | |
# [7, 3, 2, 5, 13, 7]] | |
## Step 3: Make Model ## | |
##--------------------## | |
embed = Embedding(len(vocab), 4) # embedding_dim = 4 | |
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True) # input_dim = 4, hidden_dim = 5 | |
## Step 4: Pad instances with 0s till max length sequence ## | |
##--------------------------------------------------------## | |
# get the length of each seq in your batch | |
seq_lengths = LongTensor(list(map(len, vectorized_seqs))) | |
# seq_lengths => [ 8, 4, 6] | |
# batch_sum_seq_len: 8 + 4 + 6 = 18 | |
# max_seq_len: 8 | |
seq_tensor = Variable(torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long() | |
# seq_tensor => [[0 0 0 0 0 0 0 0] | |
# [0 0 0 0 0 0 0 0] | |
# [0 0 0 0 0 0 0 0]] | |
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)): | |
seq_tensor[idx, :seqlen] = LongTensor(seq) | |
# seq_tensor => [[ 6 9 8 4 1 11 12 10] # long_str | |
# [12 5 8 14 0 0 0 0] # tiny | |
# [ 7 3 2 5 13 7 0 0]] # medium | |
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8) | |
## Step 5: Sort instances by sequence length in descending order ## | |
##---------------------------------------------------------------## | |
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True) | |
seq_tensor = seq_tensor[perm_idx] | |
# seq_tensor => [[ 6 9 8 4 1 11 12 10] # long_str | |
# [ 7 3 2 5 13 7 0 0] # medium | |
# [12 5 8 14 0 0 0 0]] # tiny | |
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8) | |
## Step 6: Embed the instances ## | |
##-----------------------------## | |
embedded_seq_tensor = embed(seq_tensor) | |
# embedded_seq_tensor => | |
# [[[-0.77578706 -1.8080667 -1.1168439 1.1059115 ] l | |
# [-0.23622951 2.0361056 0.15435742 -0.04513785] o | |
# [-0.6000342 1.1732816 0.19938554 -1.5976517 ] n | |
# [ 0.40524676 0.98665565 -0.08621677 -1.1728264 ] g | |
# [-1.6334635 -0.6100042 1.7509955 -1.931793 ] _ | |
# [-0.6470658 -0.6266589 -1.7463604 1.2675372 ] s | |
# [ 0.64004815 0.45813003 0.3476034 -0.03451729] t | |
# [-0.22739866 -0.45782727 -0.6643252 0.25129375]] r | |
# [[ 0.16031227 -0.08209462 -0.16297023 0.48121014] m | |
# [-0.7303265 -0.857339 0.58913064 -1.1068314 ] e | |
# [ 0.48159844 -1.4886451 0.92639893 0.76906884] d | |
# [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] i | |
# [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ] u | |
# [ 0.16031227 -0.08209462 -0.16297023 0.48121014] m | |
# [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ] <pad> | |
# [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]] <pad> | |
# [[ 0.64004815 0.45813003 0.3476034 -0.03451729] t | |
# [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] i | |
# [-0.6000342 1.1732816 0.19938554 -1.5976517 ] n | |
# [-1.284392 0.68294704 1.4064184 -0.42879772] y | |
# [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ] <pad> | |
# [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ] <pad> | |
# [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ] <pad> | |
# [ 0.2691206 -0.43435425 0.87935454 -2.2269666 ]]] <pad> | |
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4) | |
## Step 7: Call pack_padded_sequence with embeded instances and sequence lengths ## | |
##-------------------------------------------------------------------------------## | |
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True) | |
# packed_input (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes | |
# | |
# packed_input.data => | |
# [[-0.77578706 -1.8080667 -1.1168439 1.1059115 ] l | |
# [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ] m | |
# [-0.6470658 -0.6266589 -1.7463604 1.2675372 ] t | |
# [ 0.16031227 -0.08209462 -0.16297023 0.48121014] o | |
# [ 0.40524676 0.98665565 -0.08621677 -1.1728264 ] e | |
# [-1.284392 0.68294704 1.4064184 -0.42879772] i | |
# [ 0.64004815 0.45813003 0.3476034 -0.03451729] n | |
# [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] d | |
# [ 0.64004815 0.45813003 0.3476034 -0.03451729] n | |
# [-0.23622951 2.0361056 0.15435742 -0.04513785] g | |
# [ 0.16031227 -0.08209462 -0.16297023 0.48121014] i | |
# [-0.22739866 -0.45782727 -0.6643252 0.25129375]] y | |
# [-0.7303265 -0.857339 0.58913064 -1.1068314 ] _ | |
# [-1.6334635 -0.6100042 1.7509955 -1.931793 ] u | |
# [ 0.27616557 -1.224429 -1.342848 -0.7495876 ] s | |
# [-0.6000342 1.1732816 0.19938554 -1.5976517 ] m | |
# [-0.6000342 1.1732816 0.19938554 -1.5976517 ] t | |
# [ 0.48159844 -1.4886451 0.92639893 0.76906884] r | |
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4) | |
# | |
# packed_input.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1] | |
# visualization : | |
# l o n g _ s t r #(long_str) | |
# m e d i u m #(medium) | |
# t i n y #(tiny) | |
# 3 3 3 3 2 2 1 1 (sum = 18 [batch_sum_seq_len]) | |
## Step 8: Forward with LSTM ## | |
##---------------------------## | |
packed_output, (ht, ct) = lstm(packed_input) | |
# packed_output (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes | |
# | |
# packed_output.data : | |
# [[-0.00947162 0.07743231 0.20343193 0.29611713 0.07992904] l | |
# [ 0.08596145 0.09205993 0.20892891 0.21788561 0.00624391] o | |
# [ 0.16861682 0.07807446 0.18812777 -0.01148055 -0.01091915] n | |
# [ 0.20994528 0.17932937 0.17748171 0.05025435 0.15717036] g | |
# [ 0.01364102 0.11060348 0.14704391 0.24145307 0.12879576] _ | |
# [ 0.02610307 0.00965587 0.31438383 0.246354 0.08276576] s | |
# [ 0.09527554 0.14521319 0.1923058 -0.05925677 0.18633027] t | |
# [ 0.09872741 0.13324396 0.19446367 0.4307988 -0.05149471] r | |
# [ 0.03895474 0.08449443 0.18839942 0.02205326 0.23149511] m | |
# [ 0.14620507 0.07822411 0.2849248 -0.22616537 0.15480657] e | |
# [ 0.00884941 0.05762182 0.30557525 0.373712 0.08834908] d | |
# [ 0.12460691 0.21189159 0.04823487 0.06384943 0.28563985] i | |
# [ 0.01368293 0.15872964 0.03759198 -0.13403234 0.23890573] u | |
# [ 0.00377969 0.05943518 0.2961751 0.35107893 0.15148178] m | |
# [ 0.00737647 0.17101538 0.28344846 0.18878219 0.20339936] t | |
# [ 0.0864429 0.11173367 0.3158251 0.37537992 0.11876849] i | |
# [ 0.17885767 0.12713005 0.28287745 0.05562563 0.10871304] n | |
# [ 0.09486895 0.12772645 0.34048414 0.25930756 0.12044918]] y | |
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5) | |
# packed_output.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1] (same as packed_input.batch_sizes) | |
# visualization : | |
# l o n g _ s t r #(long_str) | |
# m e d i u m #(medium) | |
# t i n y #(tiny) | |
# 3 3 3 3 2 2 1 1 (sum = 18 [batch_sum_seq_len]) | |
## Step 9: Call unpack_padded_sequences if required / or just pick last hidden vector ## | |
##------------------------------------------------------------------------------------## | |
# unpack your output if required | |
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True) | |
# output: | |
# output => | |
# [[[-0.00947162 0.07743231 0.20343193 0.29611713 0.07992904] l | |
# [ 0.20994528 0.17932937 0.17748171 0.05025435 0.15717036] o | |
# [ 0.09527554 0.14521319 0.1923058 -0.05925677 0.18633027] n | |
# [ 0.14620507 0.07822411 0.2849248 -0.22616537 0.15480657] g | |
# [ 0.01368293 0.15872964 0.03759198 -0.13403234 0.23890573] _ | |
# [ 0.00737647 0.17101538 0.28344846 0.18878219 0.20339936] s | |
# [ 0.17885767 0.12713005 0.28287745 0.05562563 0.10871304] t | |
# [ 0.09486895 0.12772645 0.34048414 0.25930756 0.12044918]] r | |
# [[ 0.08596145 0.09205993 0.20892891 0.21788561 0.00624391] m | |
# [ 0.01364102 0.11060348 0.14704391 0.24145307 0.12879576] e | |
# [ 0.09872741 0.13324396 0.19446367 0.4307988 -0.05149471] d | |
# [ 0.00884941 0.05762182 0.30557525 0.373712 0.08834908] i | |
# [ 0.00377969 0.05943518 0.2961751 0.35107893 0.15148178] u | |
# [ 0.0864429 0.11173367 0.3158251 0.37537992 0.11876849] m | |
# [ 0. 0. 0. 0. 0. ] <pad> | |
# [ 0. 0. 0. 0. 0. ]] <pad> | |
# [[ 0.16861682 0.07807446 0.18812777 -0.01148055 -0.01091915] t | |
# [ 0.02610307 0.00965587 0.31438383 0.246354 0.08276576] i | |
# [ 0.03895474 0.08449443 0.18839942 0.02205326 0.23149511] n | |
# [ 0.12460691 0.21189159 0.04823487 0.06384943 0.28563985] y | |
# [ 0. 0. 0. 0. 0. ] <pad> | |
# [ 0. 0. 0. 0. 0. ] <pad> | |
# [ 0. 0. 0. 0. 0. ] <pad> | |
# [ 0. 0. 0. 0. 0. ]]] <pad> | |
# output.shape : ( batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5) | |
# Or if you just want the final hidden state? | |
print(ht[-1]) | |
## Summary of Shape Transformations ## | |
##----------------------------------## | |
# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim) | |
# (batch_size X max_seq_len X embedding_dim) ---> Pack ---> (batch_sum_seq_len X embedding_dim) | |
# (batch_sum_seq_len X embedding_dim) ---> LSTM ---> (batch_sum_seq_len X hidden_dim) | |
# (batch_sum_seq_len X hidden_dim) ---> UnPack ---> (batch_size X max_seq_len X hidden_dim) |
The explanation for packed_output, (ht, ct) = lstm(packed_input)
seems not correct.
should be:
# packed_output (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes
#
# packed_output.data :
# [[-0.00947162 0.07743231 0.20343193 0.29611713 0.07992904] l
# [ 0.08596145 0.09205993 0.20892891 0.21788561 0.00624391] m
# [ 0.16861682 0.07807446 0.18812777 -0.01148055 -0.01091915] t
# [ 0.20994528 0.17932937 0.17748171 0.05025435 0.15717036] o
# [ 0.01364102 0.11060348 0.14704391 0.24145307 0.12879576] e
# [ 0.02610307 0.00965587 0.31438383 0.246354 0.08276576] i
# [ 0.09527554 0.14521319 0.1923058 -0.05925677 0.18633027] n
# [ 0.09872741 0.13324396 0.19446367 0.4307988 -0.05149471] d
# [ 0.03895474 0.08449443 0.18839942 0.02205326 0.23149511] n
# [ 0.14620507 0.07822411 0.2849248 -0.22616537 0.15480657] g
# [ 0.00884941 0.05762182 0.30557525 0.373712 0.08834908] i
# [ 0.12460691 0.21189159 0.04823487 0.06384943 0.28563985] y
# [ 0.01368293 0.15872964 0.03759198 -0.13403234 0.23890573] _
# [ 0.00377969 0.05943518 0.2961751 0.35107893 0.15148178] u
# [ 0.00737647 0.17101538 0.28344846 0.18878219 0.20339936] s
# [ 0.0864429 0.11173367 0.3158251 0.37537992 0.11876849] m
# [ 0.17885767 0.12713005 0.28287745 0.05562563 0.10871304] t
# [ 0.09486895 0.12772645 0.34048414 0.25930756 0.12044918]] r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)
# packed_output.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1] (same as packed_input.batch_sizes)
# visualization :
# l o n g _ s t r #(long_str)
# m e d i u m #(medium)
# t i n y #(tiny)
# 3 3 3 3 2 2 1 1 (sum = 18 [batch_sum_seq_len])
It's really help me to understand
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is very helpful. Thank you.