# RNN inference with/without pack_padded_sequence: a comparison of outputs, last hidden states, and padding behavior.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# hyper-parameters
BATCH_SIZE = 2
SEQ_LEN = 3
EMBEDDING_DIM = 5
RNN_HIDDEN_SIZE = 5
PADDING_VALUE = 0.0
# init two embeddings in batch: first full-length and second padded
full_length_idx, padded_idx = 0, 1
lengths = torch.tensor([SEQ_LEN, SEQ_LEN - 1])
embeddings = torch.randn(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM)
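# note (assumption of this sketch): the padded element's final time step holds random values here;
# pack_padded_sequence ignores positions past `lengths`, so explicit zero-padding is not required for this comparison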
# init recurrent layer
rnn = nn.RNN(input_size=EMBEDDING_DIM, hidden_size=RNN_HIDDEN_SIZE, batch_first=True)
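# note: with the default num_layers=1 (unidirectional), last_hidden below has shape
# [num_layers, batch_size, hidden_size] = [1, BATCH_SIZE, RNN_HIDDEN_SIZE]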
# get recurrent layer output for embeddings
output, last_hidden = rnn(embeddings)
last_hidden = last_hidden.transpose(0, 1)  # [num_layers, batch_size, hidden_size] -> [batch_size, num_layers, hidden_size]
# pack embeddings
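# note: pack_padded_sequence expects lengths sorted in decreasing order unless enforce_sorted=False is passed;
# [SEQ_LEN, SEQ_LEN - 1] already satisfies this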
embeddings_packed = pack_padded_sequence(embeddings, lengths=lengths, batch_first=True)
# get recurrent layer output for packed embeddings
output_packed, last_hidden_packed = rnn(embeddings_packed)
last_hidden_packed = last_hidden_packed.transpose(0, 1)  # [num_layers, batch_size, hidden_size] -> [batch_size, num_layers, hidden_size]
# unpack recurrent layer output
output_packed, _ = pad_packed_sequence(output_packed, batch_first=True, padding_value=PADDING_VALUE)
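# pad_packed_sequence restores a [batch_size, max_seq_len, hidden_size] tensor,
# filling positions beyond each sequence's length with padding_value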
### COMPARISON AND RESULTS ###
# rnn outputs and last hidden state with/without pack_padded_sequence are equal for the full-length batch element
comparison_1_1 = torch.allclose(  # True
    output[full_length_idx],
    output_packed[full_length_idx],
)
comparison_1_2 = torch.allclose(  # True
    last_hidden[full_length_idx],
    last_hidden_packed[full_length_idx],
)
print(f"1) full-length batch element outputs equivalence: {comparison_1_1}")
print(f"1) full-length batch element last hidden state equivalence: {comparison_1_2}")
# rnn outputs and last hidden state with/without pack_padded_sequence are different for the padded batch element
comparison_2_1 = torch.allclose(  # False
    output[padded_idx],
    output_packed[padded_idx],
)
comparison_2_2 = torch.allclose(  # False
    last_hidden[padded_idx],
    last_hidden_packed[padded_idx],
)
print(f"2) padded batch element outputs equivalence: {comparison_2_1}")
print(f"2) padded batch element last hidden state equivalence: {comparison_2_2}")
# pack_padded_sequence fills output pad indices with the pad value
comparison_3 = torch.allclose(  # True
    output_packed[padded_idx][lengths[padded_idx]],
    torch.full(size=(RNN_HIDDEN_SIZE,), fill_value=PADDING_VALUE),
)
print(f"3) pad indices are filled with the pad value: {comparison_3}")
# with pack_padded_sequence, last_hidden corresponds to the last non-padded time step (without packing, it comes from the final, padded time step)
comparison_4_1 = torch.allclose(  # False
    output[padded_idx][lengths[padded_idx] - 1],
    last_hidden[padded_idx],
)
comparison_4_2 = torch.allclose(  # True
    output_packed[padded_idx][lengths[padded_idx] - 1],
    last_hidden_packed[padded_idx],
)
print(f"4) padded batch element last hidden state and last output before padding equivalence: {comparison_4_1}")
print(f"4) padded batch element last hidden state and last output before padding packed equivalence: {comparison_4_2}")
### CONCLUSION ###
# 1) rnn outputs and last hidden state with/without pack_padded_sequence are equal for the full-length batch element
# 2) rnn outputs and last hidden state with/without pack_padded_sequence are different for the padded batch element
# 3) pack_padded_sequence fills output pad indices with the pad value
# 4) with pack_padded_sequence, last_hidden corresponds to the last non-padded time step
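### OPTIONAL ROUGH TIMING SKETCH ###
# with heavy padding, packing lets the rnn skip padded time steps, which typically reduces inference time;
# the sizes below (big_batch, max_len, short_len) are illustrative assumptions only, and actual gains depend
# on hardware and the padding ratio
import time
big_batch, max_len, short_len = 64, 200, 20
big_embeddings = torch.randn(big_batch, max_len, EMBEDDING_DIM)
big_lengths = torch.tensor([max_len] + [short_len] * (big_batch - 1))  # mostly short sequences, sorted decreasing
with torch.no_grad():
    start = time.perf_counter()
    rnn(big_embeddings)
    padded_time = time.perf_counter() - start
    big_packed = pack_padded_sequence(big_embeddings, lengths=big_lengths, batch_first=True)
    start = time.perf_counter()
    rnn(big_packed)
    packed_time = time.perf_counter() - start
print(f"5) inference time, padded: {padded_time:.4f}s vs packed: {packed_time:.4f}s")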