How to use pad_packed_sequence in PyTorch < 1.1.0
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

seqs = ['gigantic_string', 'tiny_str', 'medium_str']

# make <pad> idx 0
vocab = ['<pad>'] + sorted(set(''.join(seqs)))

# make model
embed = nn.Embedding(len(vocab), 10).cuda()
lstm = nn.LSTM(10, 5).cuda()

vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]

# get the length of each seq in your batch
seq_lengths = torch.LongTensor([len(seq) for seq in vectorized_seqs]).cuda()

# dump padding everywhere, and place seqs on the left
# NOTE: you only need a tensor as big as your longest sequence
seq_tensor = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long().cuda()
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = torch.LongTensor(seq)

# SORT YOUR TENSORS BY LENGTH!
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]

# utils.rnn lets you give (B, L, D) tensors, where B is the batch size and L is the max length,
# if you use batch_first=True. Otherwise, give (L, B, D) tensors.
seq_tensor = seq_tensor.transpose(0, 1)  # (B, L) -> (L, B)

# embed your sequences
seq_tensor = embed(seq_tensor)  # (L, B) -> (L, B, D)

# pack them up nicely
packed_input = pack_padded_sequence(seq_tensor, seq_lengths.cpu().numpy())

# throw them through your LSTM (remember to give batch_first=True here if you packed with it)
packed_output, (ht, ct) = lstm(packed_input)

# unpack your output if required
output, _ = pad_packed_sequence(packed_output)  # (L, B, D)
print(output)

# Or if you just want the final hidden state?
print(ht[-1])

# REMEMBER: Your outputs are sorted. If you want the original ordering
# back (to compare to some gt labels), unsort them.
# output is (L, B, D), so move the batch dimension to the front before indexing it
_, unperm_idx = perm_idx.sort(0)
output = output.transpose(0, 1)[unperm_idx]  # (B, L, D), original order
print(output)
Since PyTorch 1.1.0, sorting the sequences by their lengths before packing is no longer needed: pytorch/pytorch#15225.
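For example, a minimal sketch assuming PyTorch >= 1.1.0 (the tensors below are made up for illustration): with enforce_sorted=False, pack_padded_sequence sorts and unsorts internally, so the unpacked output and the hidden state come back in the original batch order.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(10, 5, batch_first=True)
padded = torch.randn(3, 8, 10)      # (B, L, D), already zero-padded
lengths = torch.tensor([8, 3, 5])   # no need to sort these

packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
packed_output, (ht, ct) = lstm(packed)
output, out_lengths = pad_packed_sequence(packed_output, batch_first=True)
# output is (B, L, D) and already in the original batch order; no unsorting needed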
As an exercise, I tried to replicate this and the version by @HarshTrivedi; maybe it will be useful to someone (although I recommend the two mentioned above more): https://gist.github.com/MikulasZelinka/9fce4ed47ae74fca454e88a39f8d911a (it also includes a very basic Dataset and DataLoader example).
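For anyone who just wants the general shape of such a setup, here is a rough sketch of a Dataset plus a padding collate_fn (my own illustration, not the code from the linked gist; CharDataset and pad_collate are made-up names):

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class CharDataset(Dataset):  # hypothetical name, for illustration only
    def __init__(self, seqs, vocab):
        self.seqs = [torch.LongTensor([vocab.index(c) for c in s]) for s in seqs]

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, i):
        return self.seqs[i]

def pad_collate(batch):  # hypothetical collate_fn: pad a list of 1-D LongTensors
    lengths = torch.LongTensor([len(x) for x in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)  # <pad> is idx 0
    return padded, lengths

seqs = ['gigantic_string', 'tiny_str', 'medium_str']
vocab = ['<pad>'] + sorted(set(''.join(seqs)))
loader = DataLoader(CharDataset(seqs, vocab), batch_size=2, collate_fn=pad_collate)
for padded, lengths in loader:
    ...  # embed, pack_padded_sequence(..., batch_first=True, enforce_sorted=False), lstm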
Just like @DarryO and @icesuns said, if you want the original ordering, transpose output first.
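For concreteness, a short sketch reusing the variables from the gist above, where output means the (L, B, D) tensor returned by pad_packed_sequence before any unsorting:

_, unperm_idx = perm_idx.sort(0)

# option 1: transpose first, then index the (now leading) batch dimension
output_unsorted = output.transpose(0, 1)[unperm_idx]   # (B, L, D), original order

# option 2: index the batch dimension (dim 1) directly, without transposing
output_unsorted = output[:, unperm_idx]                # (L, B, D), original order

# the final hidden state is sorted too; unsort it the same way
ht_last = ht[-1][unperm_idx]                           # (B, H), original order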