Last active October 27, 2024 15:17
Gist: Tushar-N/dfca335e370a2bc3bc79876e6270099e
How to use pad_packed_sequence in PyTorch < 1.1.0
```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

seqs = ['gigantic_string', 'tiny_str', 'medium_str']

# make <pad> idx 0
vocab = ['<pad>'] + sorted(set(''.join(seqs)))

# make model
embed = nn.Embedding(len(vocab), 10).cuda()
lstm = nn.LSTM(10, 5).cuda()

vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]

# get the length of each seq in your batch
seq_lengths = torch.LongTensor([len(seq) for seq in vectorized_seqs]).cuda()

# dump padding everywhere, and place seqs on the left.
# NOTE: you only need a tensor as big as your longest sequence
seq_tensor = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long().cuda()
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = torch.LongTensor(seq)

# sort by length, descending: pack_padded_sequence requires this
# (in PyTorch >= 1.1 only when enforce_sorted=True, the default)
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]

# utils.rnn lets you give (B,L,D) tensors where B is the batch size, L is the maxlength, if you use batch_first=True
# Otherwise, give (L,B,D) tensors
seq_tensor = seq_tensor.transpose(0, 1)  # (B,L) -> (L,B); the embedding adds D

# embed your sequences
seq_tensor = embed(seq_tensor)  # (L,B) -> (L,B,D)

# pack them up nicely
packed_input = pack_padded_sequence(seq_tensor, seq_lengths.cpu().numpy())

# throw them through your LSTM (if you packed with batch_first=True,
# construct the LSTM with nn.LSTM(..., batch_first=True); forward() takes no such argument)
packed_output, (ht, ct) = lstm(packed_input)

# unpack your output if required
output, _ = pad_packed_sequence(packed_output)  # (L,B,H)
print(output)

# Or if you just want the final hidden state?
print(ht[-1])  # (B,H), still in sorted order

# REMEMBER: Your outputs are sorted. If you want the original ordering
# back (to compare to some gt labels) unsort them.
# The batch is dimension 1 of output (L,B,H), so index that dimension, not dimension 0.
_, unperm_idx = perm_idx.sort(0)
output = output[:, unperm_idx]
print(output)
```
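To see what "pack them up nicely" actually produces, here is a minimal CPU-only sketch (the toy lengths and values are made up for illustration) that prints the `batch_sizes` and `data` fields of a `PackedSequence`. `batch_sizes` records how many sequences are still active at each time step, and `data` stores only the real time steps, interleaved time-major, so padded positions are never stored.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Hypothetical toy batch: 3 sequences of lengths 4, 2, 1, already sorted
# in descending order as pack_padded_sequence expects by default.
lengths = [4, 2, 1]

# (L, B, D) = (4, 3, 1); value = 10 * batch index + time step, padding = -1
padded = torch.full((4, 3, 1), -1.0)
for b, n in enumerate(lengths):
    for t in range(n):
        padded[t, b, 0] = 10 * b + t

packed = pack_padded_sequence(padded, lengths)

# Number of sequences still "alive" at each time step
print(packed.batch_sizes.tolist())      # [3, 2, 1, 1]
# Time-major, interleaved data with no padding values
print(packed.data.squeeze(1).tolist())  # [0.0, 10.0, 20.0, 1.0, 11.0, 2.0, 3.0]
```

The LSTM consumes `data` in chunks of `batch_sizes[t]` rows per step, which is why the sequences must be ordered longest first.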

yuchenlin commented Jul 4, 2017

Great code!
But I wonder why PyTorch doesn't just provide a utility function that takes a list of variable-length sequences and outputs a padded and packed variable...
It's so complicated right now.
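For what it's worth, `torch.nn.utils.rnn` does ship helpers close to this: `pad_sequence` pads a list of variable-length tensors, and `pack_sequence` packs such a list directly. A minimal sketch, assuming PyTorch >= 1.1 so that `enforce_sorted=False` is available (older versions required the list to be sorted by decreasing length). Raw indices are packed here only to show the shapes; in practice you would usually embed first.

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_sequence

seqs = ['gigantic_string', 'tiny_str', 'medium_str']
vocab = ['<pad>'] + sorted(set(''.join(seqs)))  # <pad> is index 0

# One 1-D LongTensor of token indices per sequence (lengths 15, 8, 10)
tensors = [torch.tensor([vocab.index(c) for c in s]) for s in seqs]

# Pad to (L, B) with 0 = <pad>; pass batch_first=True to get (B, L)
padded = pad_sequence(tensors, padding_value=0)
print(padded.shape)  # torch.Size([15, 3])

# Pack directly from the list; enforce_sorted=False sorts internally
packed = pack_sequence(tensors, enforce_sorted=False)
print(packed.batch_sizes.tolist())
```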


lgalke commented Aug 31, 2017

Great demo code! So you don't need to bother with padding_idx of Embedding to ignore the zeros, because the packing does not even show them to the LSTM?


@Tushar-N thanks for this gist, it's awesome. In fact, it inspired me to write this little demo with some visualizations for a better understanding of batching inputs into an LSTM, where I featured your code. Many thanks :) Cheers.


hunkim commented Nov 2, 2017

Very cool!


Cheneng commented Dec 23, 2017

Helped a lot! Thanks!


Huh, GitHub never notified me about comments on the gist. Well, better late than never.
@lgalke That's right, you don't have to worry about padding_idx: with packing, the padded positions are never fed to the LSTM.
@ngarneau nice demo! And everyone else, glad I could help :)
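A quick way to check that padded positions never reach the LSTM: pad a short sequence with garbage values, run the batch through the LSTM as a packed sequence, then run the short sequence on its own and compare final hidden states. A CPU-only sketch with random inputs and arbitrary sizes:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=3)

long_seq = torch.randn(5, 4)   # length 5
short_seq = torch.randn(2, 4)  # length 2

# Pad with garbage (not zeros) to show the padding is ignored entirely
padded = torch.full((5, 2, 4), 123.0)  # (L, B, D)
padded[:, 0] = long_seq
padded[:2, 1] = short_seq

packed = pack_padded_sequence(padded, [5, 2])  # lengths sorted descending
_, (h_packed, _) = lstm(packed)

# Run the short sequence alone, with no padding at all: (L, B=1, D)
_, (h_alone, _) = lstm(short_seq.unsqueeze(1))

# The short sequence's final hidden state matches the unpadded run
print(torch.allclose(h_packed[-1, 1], h_alone[-1, 0], atol=1e-5))  # True
```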


Can we feed an (L,B) input to the embedding layer to get an (L,B,D) output? The docs say the first dimension should be the mini-batch size.


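@nikhiltitus You can. The docs describe the input as an (N,W) tensor, but Embedding looks up an embedding for every element regardless of the input's shape, so an (L,B) index tensor simply gives an (L,B,D) output.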
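To make that concrete: `nn.Embedding` accepts an index tensor of any shape and appends an embedding dimension, so embedding the transposed (L,B) indices gives exactly the same vectors as embedding (B,L) and transposing afterwards. A small sketch with made-up sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed = nn.Embedding(num_embeddings=20, embedding_dim=10)

idx_bl = torch.randint(0, 20, (3, 15))  # (B, L) token indices
idx_lb = idx_bl.transpose(0, 1)         # (L, B)

out_lb = embed(idx_lb)                  # (L, B, D)
out_bl = embed(idx_bl)                  # (B, L, D)

print(out_lb.shape)  # torch.Size([15, 3, 10])
# Same vectors either way, just laid out differently
print(torch.equal(out_lb, out_bl.transpose(0, 1)))  # True
```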


datduong commented Mar 24, 2018

Hi, I don't understand this part:

```python
# throw them through your LSTM (remember to give batch_first=True here if you packed with it)
packed_output, (ht, ct) = lstm(packed_input)
```

I used `packed_input = pack_padded_sequence(seq_tensor, seq_lengths.numpy(), batch_first=True)`, then tried `packed_output, (ht, ct) = lstm(packed_input, batch_first=True)` and got:

```
TypeError: forward() got an unexpected keyword argument 'batch_first'
```



The batch_first argument is only for initializing the LSTM (`nn.LSTM(..., batch_first=True)`); forward() doesn't take it.
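In other words, `batch_first` goes to the `nn.LSTM` constructor and to the packing/unpacking utilities, never to `forward()`. A minimal batch-first variant of the gist's pipeline, CPU-only, with random inputs standing in for the embeddings and sizes chosen for illustration:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=10, hidden_size=5, batch_first=True)  # set here

x = torch.randn(3, 15, 10)           # (B, L, D), already padded
lengths = torch.tensor([15, 10, 8])  # sorted descending

packed = pack_padded_sequence(x, lengths, batch_first=True)           # and here
packed_out, (ht, ct) = lstm(packed)                                   # not here
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)  # and here

print(out.shape)             # torch.Size([3, 15, 5])
print(ht.shape)              # torch.Size([1, 3, 5]); h_n is NOT batch-first
print(out_lengths.tolist())  # [15, 10, 8]
```

Note that `batch_first` only changes the layout of the input and output tensors; `h_n` and `c_n` stay `(num_layers, B, H)`.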


DuaneNielsen commented Apr 8, 2018

I ran this and got the following error in Python 3:

```
TypeError: torch.cuda.LongTensor constructor received an invalid combination of arguments - got (map), but expected one of...
```

The fix was to change line 24 from

```python
seq_lengths = torch.cuda.LongTensor(map(len, vectorized_seqs))
```

to

```python
seq_lengths = torch.cuda.LongTensor(list(map(len, vectorized_seqs)))
```

This is because `map` returns an iterator rather than a list in Python 3.
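The difference can be seen without torch at all: in Python 3, `map` builds a lazy, single-use iterator, and wrapping it in `list()` materializes the values the tensor constructor expects. A tiny sketch with made-up data:

```python
vectorized_seqs = [[1, 2, 3], [4, 5], [6]]

lazy = map(len, vectorized_seqs)
print(type(lazy).__name__)     # map (an iterator, not a list)

lengths = list(map(len, vectorized_seqs))
print(lengths)                 # [3, 2, 1]

# An iterator is consumed once; a second pass yields nothing
print(list(lazy), list(lazy))  # [3, 2, 1] []
```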


jojonki commented Apr 16, 2018

`seq_lengths = torch.LongTensor([len(seq) for seq in vectorized_seqs])` also works.


jizg commented May 16, 2018

Great demo, very helpful. I also used this way in my work. Thanks


It is really helpful!! Thanks very much!!


Thank you!


HarshTrivedi commented Jul 30, 2018

@Tushar-N Wonderful! A minimal example that explains everything, thanks! Here (and here) is a much more verbose version of this. I think ASCII drawings make it much simpler to visualize and understand what's happening inside.


allanj commented Aug 14, 2018

Great understanding

Agree that line 24 should be changed


@ngarneau Your demo was really helpful. Thank you very much !!


icesuns commented Nov 21, 2018

```python
# REMEMBER: Your outputs are sorted. If you want the original ordering
# back (to compare to some gt labels) unsort them
_, unperm_idx = perm_idx.sort(0)
output = output[unperm_idx]
print(output)
```

If you want the original ordering back, you should first add `output = output.transpose(1, 0)`.
Otherwise you are indexing the length dimension of `output` instead of the batch dimension, which silently reorders time steps or, when the batch is larger than the longest sequence, raises an out-of-bounds error.


Since perm_idx is obtained by sorting the lengths, should we use the following code to reverse the sort?

```python
output = output.transpose(0, 1)  # L x B x D -> B x L x D
hidden = hidden.transpose(0, 1)
output = output[unperm_idx]
hidden = hidden[unperm_idx]
```


tang1943 commented Apr 8, 2019

Just like @DarryO and @icesuns said, if you want the original ordering, transpose output first.
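Putting the thread's suggestions together: the batch lives in dimension 1 of both the (L,B,H) output and the (num_layers,B,H) hidden state, so the inverse permutation has to be applied along that dimension (equivalently, transpose first as suggested above). A CPU-only sketch with random data (sizes and seed are arbitrary) that checks the round trip against an unpadded run:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=3)

lengths = torch.tensor([2, 5, 3])  # original (unsorted) batch order
x = torch.randn(5, 3, 4)           # (L, B, D), padded

# Sort by length, as required before PyTorch 1.1.0
sorted_lengths, perm_idx = lengths.sort(0, descending=True)
packed = pack_padded_sequence(x[:, perm_idx], sorted_lengths)
packed_out, (ht, ct) = lstm(packed)
output, _ = pad_packed_sequence(packed_out)  # (L, B, H), sorted order

# Inverse permutation: unperm_idx[i] is the sorted position of original item i
_, unperm_idx = perm_idx.sort(0)
output = output[:, unperm_idx]  # index the batch dim, not dim 0
ht = ht[:, unperm_idx]          # same for (num_layers, B, H)

# Check: item 0 (length 2) run alone gives the same final hidden state
_, (h0, _) = lstm(x[:2, 0:1])
print(torch.allclose(ht[-1, 0], h0[-1, 0], atol=1e-5))  # True
```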


m-k-l-s commented May 7, 2019

Since PyTorch 1.1.0, sorting the sequences by their lengths is no longer needed: pass `enforce_sorted=False` to `pack_padded_sequence` (pytorch/pytorch#15225).
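With `enforce_sorted=False`, the sort and unsort happen inside the `PackedSequence`: `pad_packed_sequence` returns the output and lengths in the original batch order, and the LSTM returns `h_n` in that order too, so no `perm_idx`/`unperm_idx` bookkeeping is needed. A minimal sketch, assuming PyTorch >= 1.1 (sizes and seed are arbitrary):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=3)

lengths = torch.tensor([2, 5, 3])  # deliberately not sorted
x = torch.randn(5, 3, 4)           # (L, B, D), padded

packed = pack_padded_sequence(x, lengths, enforce_sorted=False)
packed_out, (ht, ct) = lstm(packed)
output, out_lengths = pad_packed_sequence(packed_out)

print(out_lengths.tolist())  # [2, 5, 3] (original order, no manual unsort)

# ht[-1, i] belongs to original item i; compare item 0 with a solo run
_, (h0, _) = lstm(x[:2, 0:1])
print(torch.allclose(ht[-1, 0], h0[-1, 0], atol=1e-5))  # True
```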

As an exercise, I tried to replicate this and the version by @HarshTrivedi (mine also includes a very basic Dataset and DataLoader example). Maybe it will be useful to someone, although I recommend the two mentioned above more.
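A minimal sketch of that kind of Dataset/DataLoader setup (not the linked code): a custom `collate_fn` pads each batch on the fly and returns the lengths needed for packing. `CharDataset` and `pad_collate` are hypothetical names chosen for illustration, and `enforce_sorted=False` assumes PyTorch >= 1.1.

```python
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

class CharDataset(Dataset):  # hypothetical toy dataset
    def __init__(self, seqs):
        self.vocab = ['<pad>'] + sorted(set(''.join(seqs)))  # <pad> = 0
        self.data = [torch.tensor([self.vocab.index(c) for c in s]) for s in seqs]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

def pad_collate(batch):  # hypothetical helper: pad per batch, keep lengths
    lengths = torch.tensor([len(x) for x in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)  # (B, L)
    return padded, lengths

ds = CharDataset(['gigantic_string', 'tiny_str', 'medium_str'])
loader = DataLoader(ds, batch_size=3, collate_fn=pad_collate)

padded, lengths = next(iter(loader))
print(padded.shape, lengths.tolist())  # torch.Size([3, 15]) [15, 8, 10]

# After embedding, pack without sorting the batch yourself
embed = torch.nn.Embedding(len(ds.vocab), 10)
packed = pack_padded_sequence(embed(padded), lengths, batch_first=True,
                              enforce_sorted=False)
```

Padding inside `collate_fn` means each batch is only as long as its own longest sequence, matching the gist's note that you only need a tensor as big as your longest sequence.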
