import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
## We want to run an LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']
#
#    Step 1: Construct Vocabulary
#    Step 2: Load indexed data (list of instances, where each instance is a list of character indices)
#    Step 3: Make Model
#  * Step 4: Pad instances with 0s up to the max-length sequence
#  * Step 5: Sort instances by sequence length in descending order
#  * Step 6: Embed the instances
#  * Step 7: Call pack_padded_sequence with the embedded instances and sequence lengths
#  * Step 8: Forward with LSTM
#  * Step 9: Call pad_packed_sequence if required / or just pick the last hidden vector
#  * Summary of Shape Transformations

# We want to run an LSTM on a batch of the following 3 character sequences
seqs = ['long_str',  # len = 8
        'tiny',      # len = 4
        'medium']    # len = 6


## Step 1: Construct Vocabulary ##
##------------------------------##

# make sure <pad> idx is 0
vocab = ['<pad>'] + sorted(set([char for seq in seqs for char in seq]))
# => ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']


## Step 2: Load indexed data (list of instances, where each instance is a list of character indices) ##
##---------------------------------------------------------------------------------------------------##
vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10],
#                     [12, 5, 8, 14],
#                     [7, 3, 2, 5, 13, 7]]


## Step 3: Make Model ##
##--------------------##

embed = Embedding(len(vocab), 4)  # embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)  # input_dim = 4, hidden_dim = 5


## Step 4: Pad instances with 0s up to the max-length sequence ##
##-------------------------------------------------------------##

# get the length of each seq in your batch
seq_lengths = LongTensor(list(map(len, vectorized_seqs)))
# seq_lengths => [8, 4, 6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

seq_tensor = Variable(torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long()
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = LongTensor(seq)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [12  5  8 14  0  0  0  0]   # tiny
#                [ 7  3  2  5 13  7  0  0]]  # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)


## Step 5: Sort instances by sequence length in descending order ##
##---------------------------------------------------------------##

seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [ 7  3  2  5 13  7  0  0]   # medium
#                [12  5  8 14  0  0  0  0]]  # tiny
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)


## Step 6: Embed the instances ##
##-----------------------------##

embedded_seq_tensor = embed(seq_tensor)
# embedded_seq_tensor =>
# [[[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]   l
#   [-0.23622951  2.0361056   0.15435742 -0.04513785]   o
#   [-0.6000342   1.1732816   0.19938554 -1.5976517 ]   n
#   [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]   g
#   [-1.6334635  -0.6100042   1.7509955  -1.931793  ]   _
#   [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]   s
#   [ 0.64004815  0.45813003  0.3476034  -0.03451729]   t
#   [-0.22739866 -0.45782727 -0.6643252   0.25129375]]  r
#
#  [[ 0.16031227 -0.08209462 -0.16297023  0.48121014]   m
#   [-0.7303265  -0.857339    0.58913064 -1.1068314 ]   e
#   [ 0.48159844 -1.4886451   0.92639893  0.76906884]   d
#   [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]   i
#   [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]   u
#   [ 0.16031227 -0.08209462 -0.16297023  0.48121014]   m
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]   <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]  <pad>
#
#  [[ 0.64004815  0.45813003  0.3476034  -0.03451729]   t
#   [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]   i
#   [-0.6000342   1.1732816   0.19938554 -1.5976517 ]   n
#   [-1.284392    0.68294704  1.4064184  -0.42879772]   y
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]   <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]   <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]   <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]] <pad>
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)


## Step 7: Call pack_padded_sequence with embedded instances and sequence lengths ##
##--------------------------------------------------------------------------------##

packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)
# packed_input (a PackedSequence) is a NamedTuple with 2 attributes: data and batch_sizes
#
# packed_input.data =>
# [[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]   l
#  [ 0.16031227 -0.08209462 -0.16297023  0.48121014]   m
#  [ 0.64004815  0.45813003  0.3476034  -0.03451729]   t
#  [-0.23622951  2.0361056   0.15435742 -0.04513785]   o
#  [-0.7303265  -0.857339    0.58913064 -1.1068314 ]   e
#  [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]   i
#  [-0.6000342   1.1732816   0.19938554 -1.5976517 ]   n
#  [ 0.48159844 -1.4886451   0.92639893  0.76906884]   d
#  [-0.6000342   1.1732816   0.19938554 -1.5976517 ]   n
#  [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]   g
#  [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]   i
#  [-1.284392    0.68294704  1.4064184  -0.42879772]   y
#  [-1.6334635  -0.6100042   1.7509955  -1.931793  ]   _
#  [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]   u
#  [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]   s
#  [ 0.16031227 -0.08209462 -0.16297023  0.48121014]   m
#  [ 0.64004815  0.45813003  0.3476034  -0.03451729]   t
#  [-0.22739866 -0.45782727 -0.6643252   0.25129375]]  r
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
#
# packed_input.batch_sizes => [3, 3, 3, 3, 2, 2, 1, 1]
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])
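# batch_sizes[t] is the number of sequences still "alive" at time step t:
# the LSTM consumes packed_input.data in these time-major chunks, so the
# <pad> positions are never processed at all.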


## Step 8: Forward with LSTM ##
##---------------------------##

packed_output, (ht, ct) = lstm(packed_input)
# packed_output (a PackedSequence) is a NamedTuple with 2 attributes: data and batch_sizes
#
# packed_output.data :
# [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#  [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#  [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#  [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#  [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#  [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#  [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#  [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#  [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#  [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#  [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#  [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#  [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#  [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#  [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#  [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#  [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#  [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)
# packed_output.batch_sizes => [3, 3, 3, 3, 2, 2, 1, 1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])


## Step 9: Call pad_packed_sequence if required / or just pick the last hidden vector ##
##------------------------------------------------------------------------------------##

# unpack your output if required
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# output =>
# [[[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#   [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#   [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#   [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#   [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#   [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#   [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#   [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r
#
#  [[ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#   [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#   [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#   [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#   [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#   [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#   [ 0.          0.          0.          0.          0.        ]   <pad>
#   [ 0.          0.          0.          0.          0.        ]]  <pad>
#
#  [[ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#   [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#   [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#   [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#   [ 0.          0.          0.          0.          0.        ]   <pad>
#   [ 0.          0.          0.          0.          0.        ]   <pad>
#   [ 0.          0.          0.          0.          0.        ]   <pad>
#   [ 0.          0.          0.          0.          0.        ]]] <pad>
# output.shape : (batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

# Or if you just want the final hidden state?
print(ht[-1])
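# ht holds the final hidden state of every (sorted) sequence:
# ht.shape : (num_layers * num_directions X batch_size X hidden_dim) = (1 X 3 X 5)
# so ht[-1] is the last layer's final hidden state, one row per sequence.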


## Summary of Shape Transformations ##
##----------------------------------##

# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) --->     Pack      ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim)        --->     LSTM      ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim)           --->    UnPack     ---> (batch_size X max_seq_len X hidden_dim)
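

## Bonus: restore the original batch order ##
##------------------------------------------##

# Everything above is in sorted order (long_str, medium, tiny). A minimal
# sketch to undo the Step 5 permutation, reusing perm_idx:
_, unperm_idx = perm_idx.sort(0)
output = output[unperm_idx]
# output rows are now back in the original ['long_str', 'tiny', 'medium'] order
# (the same indexing, ht[-1][unperm_idx], reorders the final hidden states)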
Thanks a lot for putting this together.
Line#146 is the icing on the cake.
Awesome!
Great work!
Great work.
Thanks a lot! 👍
Thank you, it is very helpful!
This is great! Congratulations!
Bro, where did the len object in line 51 come from?
Perfectly explained! I was always confused about what data goes into the batch.
Why is the "sort instances by sequence length in descending order" step needed?
pack_padded_sequence does not need sorting anymore; it's a parameter of the function (see the docs):
enforce_sorted (bool, optional) –
if True, the input is expected to contain sequences sorted by length in a decreasing order. If False, the input will get sorted unconditionally. Default: True.
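For example, a minimal sketch reusing embed, lstm, seq_tensor, and seq_lengths from the gist above (enforce_sorted was added in PyTorch 1.1; with enforce_sorted=False, Step 5 can be skipped entirely):
packed = pack_padded_sequence(embed(seq_tensor), seq_lengths.cpu(),
                              batch_first=True, enforce_sorted=False)
out, (ht, ct) = lstm(packed)
padded, lens = pad_packed_sequence(out, batch_first=True)
# padded comes back in the original (unsorted) batch order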
@jackfrost29 - len is a Python built-in function. Calling len(obj) invokes the __len__ method of whatever object you pass in, so the usual understanding is that it returns the length / size of that object. In this case, map(len, vectorized_seqs) applies it to a list of token lists, so it finds the length of every token list in vectorized_seqs.
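For instance, with vectorized_seqs from the gist:
list(map(len, vectorized_seqs))  # => [8, 4, 6]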
I wonder why nobody complains about lines 120-138: the packed sequence is clearly wrong.
For example, the first three rows in the packed sequence are not l, m, t but l, u, s. There are also too many closing brackets in line 132.
Pretty helpful, thank you
Thank you very much. It's a very helpful piece.
You sort them, but then you need to get back to the original positions, right? I want to use a hidden state; is this right?

# assuming: import torch as t
a_lengths, idx = text_length.sort(0, descending=True)
_, un_idx = t.sort(idx, dim=0)
seq = text[idx]
seq = self.dropout(self.embedding(seq))
a_packed_input = t.nn.utils.rnn.pack_padded_sequence(input=seq, lengths=a_lengths.to('cpu'), batch_first=True)
packed_output, (hidden, cell) = self.rnn(a_packed_input)
out, _ = t.nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
hidden = self.dropout(t.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
hidden = t.index_select(hidden, 0, un_idx)
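For reference, the unsort trick in isolation (a self-contained sketch):
import torch as t
idx = t.tensor([0, 2, 1])         # permutation that sorted the batch
_, un_idx = t.sort(idx, dim=0)    # its inverse
assert t.equal(idx[un_idx], t.tensor([0, 1, 2]))
index_select(hidden, 0, un_idx) therefore puts the rows back in the original order.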
Just what I was looking for, thanks.
I can't find any performance comparison. Did anyone compare using pack_padded_sequence with just a padded sequence?
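For what it's worth, a rough micro-benchmark sketch (the shapes and iteration count are arbitrary assumptions; results will vary with hardware):
import time
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

lstm = nn.LSTM(32, 64, batch_first=True)
x = torch.randn(128, 100, 32)                                  # padded batch
lengths = torch.randint(1, 101, (128,)).sort(descending=True).values

start = time.perf_counter()
for _ in range(20):
    lstm(x)                                                    # runs all 100 steps, pads included
print('padded:', time.perf_counter() - start)

packed = pack_padded_sequence(x, lengths, batch_first=True)
start = time.perf_counter()
for _ in range(20):
    lstm(packed)                                               # pad steps are skipped entirely
print('packed:', time.perf_counter() - start)

Beyond speed, note that packing also changes the result: with a plain padded batch the LSTM keeps updating its state on the pad steps, so the final hidden state is taken after the last (padded) step rather than at each sequence's true end.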
Why is the "sort instances by sequence length in descending order" step needed?
If you want to export this model to ONNX, the enforce_sorted option must be True.
However, if this model is not to be used in production, you can set enforce_sorted=False to avoid sorting.
Superb!
Very helpful!
Most easy-to-understand explanation I have read !
Awesome!
I think packed_output is wrong. It should be in the same order as packed_input. Otherwise, calling pad_packed_sequence will generate inconsistent behavior between packed_output and packed_input.
See this example:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
x = F.one_hot(torch.arange(9) % 3, num_classes=3).reshape(3, 3, 3).float()
length = torch.tensor([3, 1, 2], dtype=torch.int64)
print(x)
# tensor([[[1., 0., 0.],
# [0., 1., 0.],
# [0., 0., 1.]],
# [[1., 0., 0.],
# [0., 1., 0.],
# [0., 0., 1.]],
# [[1., 0., 0.],
# [0., 1., 0.],
# [0., 0., 1.]]])
packed_x = pack_padded_sequence(x, length, batch_first=True, enforce_sorted=False)
print(packed_x)
# PackedSequence(data=tensor([[1., 0., 0.],
# [1., 0., 0.],
# [1., 0., 0.],
# [0., 1., 0.],
# [0., 1., 0.],
# [0., 0., 1.]]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))
m = nn.RNN(3, 1, batch_first=True, bias=False)
print(m.all_weights)
# [[Parameter containing:
# tensor([[-0.7161, 0.8613, -0.8458]], requires_grad=True), Parameter containing:
# tensor([[0.2222]], requires_grad=True)]]
packed_lstm_out, _ = m.forward(packed_x)
print(packed_lstm_out)
# PackedSequence(data=tensor([[-0.6145],
# [-0.6145],
# [-0.6145],
# [ 0.6198],
# [ 0.6198],
# [-0.6094]], grad_fn=<CatBackward0>), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))
unpacked_lstm_out, unpacked_length = pad_packed_sequence(packed_lstm_out, batch_first=True)
print(unpacked_lstm_out)
# tensor([[[-0.6145],
# [ 0.6198],
# [-0.6094]],
# [[-0.6145],
# [ 0.0000],
# [ 0.0000]],
# [[-0.6145],
# [ 0.6198],
# [ 0.0000]]], grad_fn=<IndexSelectBackward0>)
This is very helpful. Thank you.
The explanation for packed_output, (ht, ct) = lstm(packed_input) seems incorrect. It should be:
# packed_output (a PackedSequence) is a NamedTuple with 2 attributes: data and batch_sizes
#
# packed_output.data :
# [[-0.00947162 0.07743231 0.20343193 0.29611713 0.07992904] l
# [ 0.08596145 0.09205993 0.20892891 0.21788561 0.00624391] m
# [ 0.16861682 0.07807446 0.18812777 -0.01148055 -0.01091915] t
# [ 0.20994528 0.17932937 0.17748171 0.05025435 0.15717036] o
# [ 0.01364102 0.11060348 0.14704391 0.24145307 0.12879576] e
# [ 0.02610307 0.00965587 0.31438383 0.246354 0.08276576] i
# [ 0.09527554 0.14521319 0.1923058 -0.05925677 0.18633027] n
# [ 0.09872741 0.13324396 0.19446367 0.4307988 -0.05149471] d
# [ 0.03895474 0.08449443 0.18839942 0.02205326 0.23149511] n
# [ 0.14620507 0.07822411 0.2849248 -0.22616537 0.15480657] g
# [ 0.00884941 0.05762182 0.30557525 0.373712 0.08834908] i
# [ 0.12460691 0.21189159 0.04823487 0.06384943 0.28563985] y
# [ 0.01368293 0.15872964 0.03759198 -0.13403234 0.23890573] _
# [ 0.00377969 0.05943518 0.2961751 0.35107893 0.15148178] u
# [ 0.00737647 0.17101538 0.28344846 0.18878219 0.20339936] s
# [ 0.0864429 0.11173367 0.3158251 0.37537992 0.11876849] m
# [ 0.17885767 0.12713005 0.28287745 0.05562563 0.10871304] t
# [ 0.09486895 0.12772645 0.34048414 0.25930756 0.12044918]] r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)
# packed_output.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1] (same as packed_input.batch_sizes)
# visualization :
# l o n g _ s t r #(long_str)
# m e d i u m #(medium)
# t i n y #(tiny)
# 3 3 3 3 2 2 1 1 (sum = 18 [batch_sum_seq_len])
It really helped me understand.
Great work!