Skip to content

Instantly share code, notes, and snippets.

@HarshTrivedi
Forked from Tushar-N/pad_packed_demo.py
Last active June 9, 2026 16:25
Show Gist options
  • Select an option

  • Save HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec to your computer and use it in GitHub Desktop.

Select an option

Save HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec to your computer and use it in GitHub Desktop.
Minimal tutorial on packing (pack_padded_sequence) and unpacking (pad_packed_sequence) sequences in pytorch.
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
## We want to run LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']
#
# Step 1: Construct Vocabulary
# Step 2: Load indexed data (list of instances, where each instance is list of character indices)
# Step 3: Make Model
# * Step 4: Pad instances with 0s till max length sequence
# * Step 5: Sort instances by sequence length in descending order
# * Step 6: Embed the instances
# * Step 7: Call pack_padded_sequence with embedded instances and sequence lengths
# * Step 8: Forward with LSTM
# * Step 9: Call pad_packed_sequence if required / or just pick last hidden vector
# * Summary of Shape Transformations
# We want to run an LSTM on a batch of the following 3 character sequences.
seqs = [
    'long_str',  # len = 8
    'tiny',      # len = 4
    'medium',    # len = 6
]

## Step 1: Construct Vocabulary ##
##------------------------------##
# '<pad>' is placed first so that the padding index is 0.
vocab = ['<pad>'] + sorted({char for seq in seqs for char in seq})
# => ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']

## Step 2: Load indexed data (list of instances, where each instance is list of character indices) ##
##-------------------------------------------------------------------------------------------------##
# A dict gives O(1) lookups; vocab.index() would rescan the list for every character.
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
vectorized_seqs = [[char_to_idx[tok] for tok in seq] for seq in seqs]
# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10],
#                     [12, 5, 8, 14],
#                     [7, 3, 2, 5, 13, 7]]

## Step 3: Make Model ##
##--------------------##
embed = Embedding(len(vocab), 4)  # embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)  # input_dim = 4, hidden_dim = 5

## Step 4: Pad instances with 0s till max length sequence ##
##--------------------------------------------------------##
# get the length of each seq in your batch
seq_lengths = torch.tensor([len(seq) for seq in vectorized_seqs], dtype=torch.long)
# seq_lengths => [ 8, 4, 6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

# Allocate the padded batch directly as int64; 0 is the <pad> index.
max_seq_len = int(seq_lengths.max())
seq_tensor = torch.zeros((len(vectorized_seqs), max_seq_len), dtype=torch.long)
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

# Copy every sequence into the left part of its row; the tail stays 0 (<pad>).
for idx, seq in enumerate(vectorized_seqs):
    seq_tensor[idx, :len(seq)] = torch.tensor(seq, dtype=torch.long)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [12  5  8 14  0  0  0  0]   # tiny
#                [ 7  3  2  5 13  7  0  0]]  # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

## Step 5: Sort instances by sequence length in descending order ##
##---------------------------------------------------------------##
# pack_padded_sequence expects length-sorted input by default (enforce_sorted=True);
# passing enforce_sorted=False instead lets it sort internally, making this step optional.
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [ 7  3  2  5 13  7  0  0]   # medium
#                [12  5  8 14  0  0  0  0]]  # tiny
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

## Step 6: Embed the instances ##
##-----------------------------##
embedded_seq_tensor = embed(seq_tensor)
# (the numbers below are one example run; embeddings are randomly initialised)
# embedded_seq_tensor =>
# [[[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]   l
#   [-0.23622951  2.0361056   0.15435742 -0.04513785]   o
#   [-0.6000342   1.1732816   0.19938554 -1.5976517 ]   n
#   [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]   g
#   [-1.6334635  -0.6100042   1.7509955  -1.931793  ]   _
#   [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]   s
#   [ 0.64004815  0.45813003  0.3476034  -0.03451729]   t
#   [-0.22739866 -0.45782727 -0.6643252   0.25129375]]  r
#  [[ 0.16031227 -0.08209462 -0.16297023  0.48121014]   m
#   [-0.7303265  -0.857339    0.58913064 -1.1068314 ]   e
#   [ 0.48159844 -1.4886451   0.92639893  0.76906884]   d
#   [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]   i
#   [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]   u
#   [ 0.16031227 -0.08209462 -0.16297023  0.48121014]   m
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]   <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]  <pad>
#  [[ 0.64004815  0.45813003  0.3476034  -0.03451729]   t
#   [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]   i
#   [-0.6000342   1.1732816   0.19938554 -1.5976517 ]   n
#   [-1.284392    0.68294704  1.4064184  -0.42879772]   y
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]   <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]   <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]   <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]] <pad>
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)

## Step 7: Call pack_padded_sequence with embedded instances and sequence lengths ##
##--------------------------------------------------------------------------------##
# lengths must be a CPU tensor (or a list of ints), even when the data lives on a GPU.
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu(), batch_first=True)
# packed_input is a PackedSequence (a NamedTuple); the two attributes of interest
# here are data and batch_sizes.
#
# data is laid out time step by time step: first character of every sequence,
# then second character of every sequence that is still running, and so on.
# Padding rows are dropped entirely.
#
# packed_input.data =>
# [[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]   l
#  [ 0.16031227 -0.08209462 -0.16297023  0.48121014]   m
#  [ 0.64004815  0.45813003  0.3476034  -0.03451729]   t
#  [-0.23622951  2.0361056   0.15435742 -0.04513785]   o
#  [-0.7303265  -0.857339    0.58913064 -1.1068314 ]   e
#  [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]   i
#  [-0.6000342   1.1732816   0.19938554 -1.5976517 ]   n
#  [ 0.48159844 -1.4886451   0.92639893  0.76906884]   d
#  [-0.6000342   1.1732816   0.19938554 -1.5976517 ]   n
#  [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]   g
#  [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]   i
#  [-1.284392    0.68294704  1.4064184  -0.42879772]   y
#  [-1.6334635  -0.6100042   1.7509955  -1.931793  ]   _
#  [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]   u
#  [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]   s
#  [ 0.16031227 -0.08209462 -0.16297023  0.48121014]   m
#  [ 0.64004815  0.45813003  0.3476034  -0.03451729]   t
#  [-0.22739866 -0.45782727 -0.6643252   0.25129375]]  r
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
#
# packed_input.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1]
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

## Step 8: Forward with LSTM ##
##---------------------------##
packed_output, (ht, ct) = lstm(packed_input)
# packed_output is a PackedSequence with the same time-major row order and the
# same batch_sizes as packed_input.
#
# packed_output.data :
# [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#  [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#  [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#  [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#  [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#  [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#  [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#  [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#  [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#  [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#  [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#  [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#  [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#  [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#  [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#  [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#  [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#  [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)
# packed_output.batch_sizes => [ 3, 3, 3, 3, 2, 2, 1, 1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

## Step 9: Call pad_packed_sequence if required / or just pick last hidden vector ##
##--------------------------------------------------------------------------------##
# unpack your output if required
# input_sizes holds the true (unpadded) length of every row of output.
# Rows are still in the length-sorted order from Step 5; index with
# perm_idx.argsort() to restore the original order of seqs.
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# output =>
# [[[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#   [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#   [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#   [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#   [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#   [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#   [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#   [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r
#  [[ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#   [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#   [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#   [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#   [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#   [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#   [ 0.          0.          0.          0.          0.        ]   <pad>
#   [ 0.          0.          0.          0.          0.        ]]  <pad>
#  [[ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#   [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#   [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#   [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#   [ 0.          0.          0.          0.          0.        ]   <pad>
#   [ 0.          0.          0.          0.          0.        ]   <pad>
#   [ 0.          0.          0.          0.          0.        ]   <pad>
#   [ 0.          0.          0.          0.          0.        ]]] <pad>
# output.shape : (batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

# Or if you just want the final hidden state?
# ht[-1] has one row per sequence (same length-sorted order as output) holding the
# hidden state at that sequence's last real character, not at a padding position.
print(ht[-1])
## Summary of Shape Transformations ##
##----------------------------------##
# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) ---> Pack ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim) ---> LSTM ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim) ---> UnPack ---> (batch_size X max_seq_len X hidden_dim)
@nilinykh

Copy link
Copy Markdown

Great tutorial!
A question: in which situations might we want to use pad_packed_sequence? When calculating the loss, wouldn't it be simpler to work with packed (unpadded) scores (LSTM outputs) and packed targets? Or do I sometimes have to pad scores and targets using pad_packed_sequence? If so, when is it used?

@LearningHarder

Copy link
Copy Markdown

Thank you very much!

@kunalmessi10

Copy link
Copy Markdown

Great work!

@spookyQubit

Copy link
Copy Markdown

Thanks a lot for putting this together.

@RudRho

RudRho commented Jul 20, 2020

Copy link
Copy Markdown

Line#146 is the icing on the cake.

Awesome!

@shihanmax

Copy link
Copy Markdown

Great work!

@PhaneendraGunda

Copy link
Copy Markdown

Great work.

@realyanyang

Copy link
Copy Markdown

Thanks a lot! 👍

@mrnewman55

Copy link
Copy Markdown

thank you, it is very helpful!

@davidevegliante

Copy link
Copy Markdown

This is great! Congratulation

@jackfrost29

Copy link
Copy Markdown

Bro where did the len object in line 51 come from?

@rajy4683

Copy link
Copy Markdown

Perfectly explained! Was always confused on what data goes into the batch.

@laifi

laifi commented Nov 19, 2020

Copy link
Copy Markdown

Why sort instances by sequence length in descending order step is needed?

pack_padded_sequence does not need sorting anymore,its a parameter in the function (Doc)

**enforce_sorted** (bool, optional) –if True, the input is expected to contain sequences sorted by length in a decreasing order. If False, the input will get sorted unconditionally. Default: True.

@rayryeng

rayryeng commented Nov 22, 2020

Copy link
Copy Markdown

@jackfrost29 - len is a built-in function. When you call len, it invokes the __len__ method of whatever object is passed in. The usual understanding of len is that it finds the length / size of the object you pass to it. In this case, it is mapped over vectorized_seqs, which is a list of token lists, so it finds the length of every token list in vectorized_seqs.

@timeamagyar

Copy link
Copy Markdown

Wonder why nobody complains about lines 120-138, as the packed sequence is clearly wrong.

Clearly, the first three rows in the packed sequence are not l, m, t but l, u, s for example. There are also too many closing brackets in line 132.

@tombosc

tombosc commented Feb 25, 2021

Copy link
Copy Markdown

Pretty helpful, thank you

@duyupeng

Copy link
Copy Markdown

Thankyou very much.It's a very important paper.

@Dongximing

Copy link
Copy Markdown

You sort them, so afterwards you need to restore the original order, right? I want to use the hidden state; is the following correct?
''' a_lengths, idx = text_length.sort(0, descending=True)
_, un_idx = t.sort(idx, dim=0)
seq = text[idx]

    seq = self.dropout(self.embedding(seq))

    a_packed_input = t.nn.utils.rnn.pack_padded_sequence(input=seq, lengths=a_lengths.to('cpu'), batch_first=True)
    packed_output, (hidden, cell) = self.rnn(a_packed_input)
    out, _ = t.nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
    hidden = self.dropout(t.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))

    hidden = t.index_select(hidden, 0, un_idx)

@lparsenadze

Copy link
Copy Markdown

just what i was looking for, thanks

@elch10

elch10 commented Dec 30, 2021

Copy link
Copy Markdown

I can't find any performance comparison. Did anyone compare using pack_padded_sequence with just a padded sequence?

@davebulaval

Copy link
Copy Markdown

@Y-jiji

Y-jiji commented Mar 19, 2022

Copy link
Copy Markdown

Why sort instances by sequence length in descending order step is needed?

If you want to export this model as ONNX, enforce_sorted option must be True.
However, if this model is not to be used in production, you can set enforce_sorted=False to avoid sorting.

@jainraj

jainraj commented Apr 16, 2022

Copy link
Copy Markdown

Superb!

@JoyceXu02

Copy link
Copy Markdown

Very helpful!

@hungkien05

hungkien05 commented Jul 7, 2022

Copy link
Copy Markdown

Most easy-to-understand explanation I have read !

@nqchieutb01

Copy link
Copy Markdown

awnsome !

@gunnxx

gunnxx commented Dec 6, 2022

Copy link
Copy Markdown

I think packed_output is wrong. It should be the same order as the packed_input. Otherwise, calling pad_packed_sequence will generate inconsistent behavior between packed_output and packed_input.

See this example

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

x = F.one_hot(torch.arange(9) % 3, num_classes=3).reshape(3, 3, 3).float()
length = torch.tensor([3, 1, 2], dtype=torch.int64)
print(x)
# tensor([[[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 1.]],

#         [[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 1.]],

#         [[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 1.]]])

packed_x = pack_padded_sequence(x, length, batch_first=True, enforce_sorted=False)
print(packed_x)
# PackedSequence(data=tensor([[1., 0., 0.],
#         [1., 0., 0.],
#         [1., 0., 0.],
#         [0., 1., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))

m = nn.RNN(3, 1, batch_first=True, bias=False)
print(m.all_weights)
# [[Parameter containing:
# tensor([[-0.7161,  0.8613, -0.8458]], requires_grad=True), Parameter containing:
# tensor([[0.2222]], requires_grad=True)]]

packed_lstm_out, _ = m.forward(packed_x)
print(packed_lstm_out)
# PackedSequence(data=tensor([[-0.6145],
#         [-0.6145],
#         [-0.6145],
#         [ 0.6198],
#         [ 0.6198],
#         [-0.6094]], grad_fn=<CatBackward0>), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))

unpacked_lstm_out, unpacked_length = pad_packed_sequence(packed_lstm_out, batch_first=True)
print(unpacked_lstm_out)
# tensor([[[-0.6145],
#          [ 0.6198],
#          [-0.6094]],

#         [[-0.6145],
#          [ 0.0000],
#          [ 0.0000]],

#         [[-0.6145],
#          [ 0.6198],
#          [ 0.0000]]], grad_fn=<IndexSelectBackward0>)

@luyang-ai4med

Copy link
Copy Markdown

This is very helpful. Thank you.

@XianghengHee

XianghengHee commented Feb 7, 2024

Copy link
Copy Markdown

The explanation for packed_output, (ht, ct) = lstm(packed_input) seems not correct.

should be:

# packed_output (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes
#
# packed_output.data :
#                          [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                           [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                           [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                           [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                           [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                           [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                           [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                           [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                           [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                           [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                           [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                           [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                           [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                           [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                           [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                           [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                           [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                           [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)

# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

@Cendra123

Copy link
Copy Markdown

It's really help me to understand

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment