-
-
Save HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec to your computer and use it in GitHub Desktop.
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

## We want to run an LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']
#
# Step 1: Construct Vocabulary
# Step 2: Load indexed data (list of instances, where each instance is a list of character indices)
# Step 3: Make Model
# * Step 4: Pad instances with 0s up to the max sequence length
# * Step 5: Sort instances by sequence length in descending order
# * Step 6: Embed the instances
# * Step 7: Call pack_padded_sequence with the embedded instances and sequence lengths
# * Step 8: Forward with LSTM
# * Step 9: Call pad_packed_sequence if required / or just pick the last hidden vector
# * Step 10: Undo the sort so results line up with the original batch order
# * Summary of Shape Transformations
#
# NOTE: the embedding and LSTM weights are randomly initialised (no seed is set),
# so the numbers in the dumps below are one example run; yours will differ.

# We want to run an LSTM on a batch of the following 3 character sequences
seqs = ['long_str',  # len = 8
        'tiny',      # len = 4
        'medium']    # len = 6

## Step 1: Construct Vocabulary ##
##------------------------------##
# make sure <pad> idx is 0 (Step 4 pads with zeros)
vocab = ['<pad>'] + sorted(set([char for seq in seqs for char in seq]))
# => ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']

## Step 2: Load indexed data (list of instances, where each instance is a list of character indices) ##
##----------------------------------------------------------------------------------------------------##
# A dict gives O(1) lookups instead of scanning the vocab list for every character.
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
vectorized_seqs = [[char_to_idx[tok] for tok in seq] for seq in seqs]
# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10],
#                     [12, 5, 8, 14],
#                     [7, 3, 2, 5, 13, 7]]

## Step 3: Make Model ##
##--------------------##
EMBEDDING_DIM = 4
HIDDEN_DIM = 5
embed = Embedding(len(vocab), EMBEDDING_DIM)
lstm = LSTM(input_size=EMBEDDING_DIM, hidden_size=HIDDEN_DIM, batch_first=True)

## Step 4: Pad instances with 0s up to the max sequence length ##
##-------------------------------------------------------------##
# get the length of each seq in your batch
seq_lengths = LongTensor([len(seq) for seq in vectorized_seqs])
# seq_lengths => [ 8, 4, 6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8
seq_tensor = torch.zeros((len(vectorized_seqs), int(seq_lengths.max())), dtype=torch.long)
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]
for idx, seq in enumerate(vectorized_seqs):
    seq_tensor[idx, :len(seq)] = LongTensor(seq)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [12  5  8 14  0  0  0  0]   # tiny
#                [ 7  3  2  5 13  7  0  0]]  # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

## Step 5: Sort instances by sequence length in descending order ##
##---------------------------------------------------------------##
# pack_padded_sequence's default enforce_sorted=True requires lengths in
# descending order. (Alternatively pass enforce_sorted=False and PyTorch sorts
# and un-sorts internally.) We keep perm_idx so Step 10 can undo this sort.
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
# seq_lengths => [ 8, 6, 4]    perm_idx => [0, 2, 1]
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [ 7  3  2  5 13  7  0  0]   # medium
#                [12  5  8 14  0  0  0  0]]  # tiny
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

## Step 6: Embed the instances ##
##-----------------------------##
embedded_seq_tensor = embed(seq_tensor)
# embedded_seq_tensor =>
# [[[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l
#   [-0.23622951  2.0361056   0.15435742 -0.04513785]     o
#   [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#   [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     g
#   [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     _
#   [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     s
#   [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#   [-0.22739866 -0.45782727 -0.6643252   0.25129375]]    r
#  [[ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#   [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e
#   [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d
#   [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#   [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u
#   [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]    <pad>
#  [[ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#   [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#   [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#   [-1.284392    0.68294704  1.4064184  -0.42879772]     y
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]]   <pad>
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)

## Step 7: Call pack_padded_sequence with the embedded instances and sequence lengths ##
##------------------------------------------------------------------------------------##
# lengths must be a CPU tensor (or a list of ints).
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu(), batch_first=True)
# packed_input (PackedSequence is a NamedTuple; data and batch_sizes are the fields used here)
#
# packed_input.data is time-major: step 0 of every sequence, then step 1, ...
# Padding positions are dropped.
# packed_input.data =>
# [[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]    l
#  [ 0.16031227 -0.08209462 -0.16297023  0.48121014]    m
#  [ 0.64004815  0.45813003  0.3476034  -0.03451729]    t
#  [-0.23622951  2.0361056   0.15435742 -0.04513785]    o
#  [-0.7303265  -0.857339    0.58913064 -1.1068314 ]    e
#  [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]    i
#  [-0.6000342   1.1732816   0.19938554 -1.5976517 ]    n
#  [ 0.48159844 -1.4886451   0.92639893  0.76906884]    d
#  [-0.6000342   1.1732816   0.19938554 -1.5976517 ]    n
#  [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]    g
#  [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]    i
#  [-1.284392    0.68294704  1.4064184  -0.42879772]    y
#  [-1.6334635  -0.6100042   1.7509955  -1.931793  ]    _
#  [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]    u
#  [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]    s
#  [ 0.16031227 -0.08209462 -0.16297023  0.48121014]    m
#  [ 0.64004815  0.45813003  0.3476034  -0.03451729]    t
#  [-0.22739866 -0.45782727 -0.6643252   0.25129375]]   r
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
#
# packed_input.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1]
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

## Step 8: Forward with LSTM ##
##---------------------------##
packed_output, (ht, ct) = lstm(packed_input)
# packed_output (PackedSequence is a NamedTuple; data and batch_sizes are the fields used here)
#
# packed_output.data (same time-major order as packed_input.data) :
# [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]    l
#  [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]    m
#  [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]    t
#  [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]    o
#  [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]    e
#  [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]    i
#  [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]    n
#  [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]    d
#  [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]    n
#  [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]    g
#  [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]    i
#  [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]    y
#  [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]    _
#  [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]    u
#  [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]    s
#  [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]    m
#  [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]    t
#  [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]   r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)
# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])
#
# ht / ct : (num_layers X batch_size X hidden_dim) = (1 X 3 X 5), still in sorted batch order

## Step 9: Call pad_packed_sequence if required / or just pick the last hidden vector ##
##------------------------------------------------------------------------------------##
# unpack your output if required
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# input_sizes => [8, 6, 4]  (sorted order, like seq_lengths)
# output (rows still in sorted order: long_str, medium, tiny) =>
# [[[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]     l
#   [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]     o
#   [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]     n
#   [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]     g
#   [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]     _
#   [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]     s
#   [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]     t
#   [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]    r
#  [[ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]     m
#   [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]     e
#   [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]     d
#   [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]     i
#   [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]     u
#   [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]     m
#   [ 0.          0.          0.          0.          0.        ]     <pad>
#   [ 0.          0.          0.          0.          0.        ]]    <pad>
#  [[ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]     t
#   [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]     i
#   [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]     n
#   [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]     y
#   [ 0.          0.          0.          0.          0.        ]     <pad>
#   [ 0.          0.          0.          0.          0.        ]     <pad>
#   [ 0.          0.          0.          0.          0.        ]     <pad>
#   [ 0.          0.          0.          0.          0.        ]]]   <pad>
# output.shape : ( batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

## Step 10: Undo the sort so results line up with the original batch order ##
##-------------------------------------------------------------------------##
# Sorting perm_idx yields the inverse permutation: unperm_idx[i] is the sorted
# position of original instance i.
_, unperm_idx = perm_idx.sort(0)
unsorted_output = output[unperm_idx]    # rows: long_str, tiny, medium
# Or if you just want the final hidden state (last layer) of each sequence:
last_hidden = ht[-1][unperm_idx]
print(last_hidden)
# last_hidden =>
# [[ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]    long_str (last char: r)
#  [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]    tiny     (last char: y)
#  [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]]   medium   (last char: m)

## Summary of Shape Transformations ##
##----------------------------------##
# (batch_size X max_seq_len)                 --> Sort by seqlen --> (batch_size X max_seq_len)
# (batch_size X max_seq_len)                 --> Embed          --> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) --> Pack           --> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim)        --> LSTM           --> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim)           --> UnPack         --> (batch_size X max_seq_len X hidden_dim)
# (batch_size X max_seq_len X hidden_dim)    --> Unsort         --> (batch_size X max_seq_len X hidden_dim)
Perfectly explained! Was always confused on what data goes into the batch.
Why sort instances by sequence length in descending order step is needed?
`pack_padded_sequence` no longer requires pre-sorting; this is controlled by its `enforce_sorted` parameter (from the docs):
**enforce_sorted** (bool, optional) –
if True, the input is expected to contain sequences sorted by length in a decreasing order. If False, the input will get sorted unconditionally. Default: True.
@jackfrost29 - `len` is a built-in function. Calling `len(obj)` invokes the `__len__` method of whatever object is passed in. The usual understanding of `len` is that it returns the length / size of whatever object you pass to it. In this case, `map(len, vectorized_seqs)` applies it to each element of a list of token lists, so it finds the length of every token list in `vectorized_seqs`.
I wonder why nobody has complained about lines 120-138; the packed sequence shown there is clearly wrong.
Clearly, the first three rows in the packed sequence are not l, m, t but l, u, s for example. There are also too many closing brackets in line 132.
Pretty helpful, thank you
Thank you very much. It's a very important paper.
You sort them, so afterwards you need to restore the original positions, right? I want to use the hidden state; is this the right way?
''' a_lengths, idx = text_length.sort(0, descending=True)
_, un_idx = t.sort(idx, dim=0)
seq = text[idx]
seq = self.dropout(self.embedding(seq))
a_packed_input = t.nn.utils.rnn.pack_padded_sequence(input=seq, lengths=a_lengths.to('cpu'), batch_first=True)
packed_output, (hidden, cell) = self.rnn(a_packed_input)
out, _ = t.nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
hidden = self.dropout(t.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
hidden = t.index_select(hidden, 0, un_idx)
just what i was looking for, thanks
I can't find any performance comparison. Did anyone compare using `pack_padded_sequence` against just a padded sequence?
Why sort instances by sequence length in descending order step is needed?
If you want to export this model as ONNX, enforce_sorted option must be True.
However, if this model is not to be used in production, you can set enforce_sorted=False to avoid sorting.
Superb!
Very helpful!
Most easy-to-understand explanation I have read !
Awesome!
I think the `packed_output` shown is wrong. It should be in the same order as `packed_input`. Otherwise, calling `pad_packed_sequence` would produce behavior inconsistent between `packed_output` and `packed_input`.
See this example
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
# Demo: packing with enforce_sorted=False and checking that pad_packed_sequence
# restores the original batch order. RNN weights are random (no seed), so the
# printed numbers below are one example run.

# Three identical sequences: x[b, t] is one_hot(t) for t in 0..2.
x = F.one_hot(torch.arange(9) % 3, num_classes=3).reshape(3, 3, 3).float()
# Valid lengths per batch element; deliberately NOT in descending order.
length = torch.tensor([3, 1, 2], dtype=torch.int64)
print(x)
# tensor([[[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 1.]],
#         [[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 1.]],
#         [[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 1.]]])
# enforce_sorted=False: PyTorch sorts internally and records sorted_indices /
# unsorted_indices so the original order can be restored on unpacking.
packed_x = pack_padded_sequence(x, length, batch_first=True, enforce_sorted=False)
print(packed_x)
# PackedSequence(data=tensor([[1., 0., 0.],
#         [1., 0., 0.],
#         [1., 0., 0.],
#         [0., 1., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))
m = nn.RNN(3, 1, batch_first=True, bias=False)
print(m.all_weights)
# [[Parameter containing:
# tensor([[-0.7161,  0.8613, -0.8458]], requires_grad=True), Parameter containing:
# tensor([[0.2222]], requires_grad=True)]]
# Call the module itself rather than m.forward(...) so registered hooks run.
packed_lstm_out, _ = m(packed_x)
print(packed_lstm_out)
# PackedSequence(data=tensor([[-0.6145],
#         [-0.6145],
#         [-0.6145],
#         [ 0.6198],
#         [ 0.6198],
#         [-0.6094]], grad_fn=<CatBackward0>), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))
# pad_packed_sequence uses unsorted_indices to return rows (and lengths) in the
# original batch order; padded positions are filled with 0.
unpacked_lstm_out, unpacked_length = pad_packed_sequence(packed_lstm_out, batch_first=True)
print(unpacked_lstm_out)
# tensor([[[-0.6145],
#          [ 0.6198],
#          [-0.6094]],
#         [[-0.6145],
#          [ 0.0000],
#          [ 0.0000]],
#         [[-0.6145],
#          [ 0.6198],
#          [ 0.0000]]], grad_fn=<IndexSelectBackward0>)
This is very helpful. Thank you.
The explanation for `packed_output, (ht, ct) = lstm(packed_input)` seems incorrect. It should be:
# packed_output (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes
#
# packed_output.data :
# [[-0.00947162 0.07743231 0.20343193 0.29611713 0.07992904] l
# [ 0.08596145 0.09205993 0.20892891 0.21788561 0.00624391] m
# [ 0.16861682 0.07807446 0.18812777 -0.01148055 -0.01091915] t
# [ 0.20994528 0.17932937 0.17748171 0.05025435 0.15717036] o
# [ 0.01364102 0.11060348 0.14704391 0.24145307 0.12879576] e
# [ 0.02610307 0.00965587 0.31438383 0.246354 0.08276576] i
# [ 0.09527554 0.14521319 0.1923058 -0.05925677 0.18633027] n
# [ 0.09872741 0.13324396 0.19446367 0.4307988 -0.05149471] d
# [ 0.03895474 0.08449443 0.18839942 0.02205326 0.23149511] n
# [ 0.14620507 0.07822411 0.2849248 -0.22616537 0.15480657] g
# [ 0.00884941 0.05762182 0.30557525 0.373712 0.08834908] i
# [ 0.12460691 0.21189159 0.04823487 0.06384943 0.28563985] y
# [ 0.01368293 0.15872964 0.03759198 -0.13403234 0.23890573] _
# [ 0.00377969 0.05943518 0.2961751 0.35107893 0.15148178] u
# [ 0.00737647 0.17101538 0.28344846 0.18878219 0.20339936] s
# [ 0.0864429 0.11173367 0.3158251 0.37537992 0.11876849] m
# [ 0.17885767 0.12713005 0.28287745 0.05562563 0.10871304] t
# [ 0.09486895 0.12772645 0.34048414 0.25930756 0.12044918]] r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)
# packed_output.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1] (same as packed_input.batch_sizes)
# visualization :
# l o n g _ s t r #(long_str)
# m e d i u m #(medium)
# t i n y #(tiny)
# 3 3 3 3 2 2 1 1 (sum = 18 [batch_sum_seq_len])
It really helped me understand.
Bro where did the len object in line 51 come from?