RNN inference time with/without pack_padded_sequence comparison.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# hyper-parameters
BATCH_SIZE = 2
SEQ_LEN = 3
EMBEDDING_DIM = 5
RNN_HIDDEN_SIZE = 5
PADDING_VALUE = 0.0
# init a batch of two embedded sequences: the first is full-length, the second is one step shorter
# (its last timestep is padding that holds random values; the packed run will ignore it)
full_length_idx, padded_idx = 0, 1
lengths = torch.tensor([SEQ_LEN, SEQ_LEN - 1])
embeddings = torch.randn(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM)
# init recurrent layer
rnn = nn.RNN(input_size=EMBEDDING_DIM, hidden_size=RNN_HIDDEN_SIZE, batch_first=True)
# get recurrent layer output for embeddings
output, last_hidden = rnn(embeddings)
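# output: [BATCH_SIZE, SEQ_LEN, RNN_HIDDEN_SIZE]
# last_hidden: [num_layers * num_directions, BATCH_SIZE, RNN_HIDDEN_SIZE] = [1, 2, 5]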
last_hidden = last_hidden.transpose(0, 1)  # -> [batch_size, num_layers * num_directions, hidden_size] so it can be indexed by batch element
# pack embeddings
embeddings_packed = pack_padded_sequence(embeddings, lengths=lengths, batch_first=True)
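# embeddings_packed is a PackedSequence: .data holds the non-padded timesteps flattened
# time-major (3 + 2 = 5 rows of size EMBEDDING_DIM here), and .batch_sizes holds the number
# of sequences still active at each timestep (tensor([2, 2, 1]) here)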
# get recurrent layer output for packed embeddings
output_packed, last_hidden_packed = rnn(embeddings_packed)
last_hidden_packed = last_hidden_packed.transpose(0, 1)  # -> [batch_size, num_layers * num_directions, hidden_size] so it can be indexed by batch element
# unpack recurrent layer output
output_packed, _ = pad_packed_sequence(output_packed, batch_first=True, padding_value=PADDING_VALUE)
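# after unpacking, output_packed has shape [BATCH_SIZE, max(lengths), RNN_HIDDEN_SIZE];
# positions beyond each sequence's length are filled with PADDING_VALUE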
### COMPARISON AND RESULTS ###
# rnn outputs and last hidden state with/without pack_padded_sequence are equal for full-length batch element
comparison_1_1 = torch.allclose(  # True
    output[full_length_idx],
    output_packed[full_length_idx],
)
comparison_1_2 = torch.allclose(  # True
    last_hidden[full_length_idx],
    last_hidden_packed[full_length_idx],
)
print(f"1) full-length batch element outputs equivalence: {comparison_1_1}")
print(f"1) full-length batch element last hidden state equivalence: {comparison_1_2}")
# rnn outputs and last hidden state with/without pack_padded_sequence are different for padded batch element
comparison_2_1 = torch.allclose(  # False
    output[padded_idx],
    output_packed[padded_idx],
)
comparison_2_2 = torch.allclose(  # False
    last_hidden[padded_idx],
    last_hidden_packed[padded_idx],
)
print(f"2) padded batch element outputs equivalence: {comparison_2_1}")
print(f"2) padded batch element last hidden state equivalence: {comparison_2_2}")
# pack_padded_sequence fills output positions beyond each sequence's length with the pad value
comparison_3 = torch.allclose(  # True
    output_packed[padded_idx][lengths[padded_idx]],
    torch.full(size=(RNN_HIDDEN_SIZE,), fill_value=PADDING_VALUE),  # output features have size RNN_HIDDEN_SIZE
)
print(f"3) pad indices are filled with the pad value: {comparison_3}")
# with pack_padded_sequence, the last hidden state corresponds to the last non-padded timestep;
# without packing, it corresponds to the final (padded) timestep
comparison_4_1 = torch.allclose(  # False
    output[padded_idx][lengths[padded_idx] - 1],
    last_hidden[padded_idx],
)
comparison_4_2 = torch.allclose(  # True
    output_packed[padded_idx][lengths[padded_idx] - 1],
    last_hidden_packed[padded_idx],
)
print(f"4) padded batch element: last hidden state equals last output before padding (no packing): {comparison_4_1}")
print(f"4) padded batch element: last hidden state equals last output before padding (packed): {comparison_4_2}")
### CONCLUSION ###
# 1) rnn outputs and last hidden state with/without pack_padded_sequence are equal for the full-length batch element
# 2) rnn outputs and last hidden state with/without pack_padded_sequence differ for the padded batch element
# 3) pack_padded_sequence fills output positions beyond each sequence's length with the pad value
# 4) with pack_padded_sequence, the last hidden state corresponds to the last non-padded timestep (without packing, it comes from the final padded timestep)
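
### INFERENCE TIME SKETCH ###
# the gist description mentions inference time; below is a minimal, hedged timing sketch
# (assumptions not in the original code: CPU timing via time.perf_counter and a larger
# synthetic batch, TIME_BATCH_SIZE / TIME_SEQ_LEN, so the difference is measurable)
import time

TIME_BATCH_SIZE = 256  # assumed batch size for the timing run
TIME_SEQ_LEN = 100  # assumed max sequence length for the timing run
timing_embeddings = torch.randn(TIME_BATCH_SIZE, TIME_SEQ_LEN, EMBEDDING_DIM)
# random lengths sorted in descending order, since pack_padded_sequence defaults to enforce_sorted=True
timing_lengths = torch.randint(low=1, high=TIME_SEQ_LEN + 1, size=(TIME_BATCH_SIZE,)).sort(descending=True).values

with torch.no_grad():
    # forward pass over the fully padded batch
    start = time.perf_counter()
    rnn(timing_embeddings)
    padded_time = time.perf_counter() - start

    # forward pass over the packed batch (timing includes the pack_padded_sequence call itself)
    start = time.perf_counter()
    rnn(pack_padded_sequence(timing_embeddings, lengths=timing_lengths, batch_first=True))
    packed_time = time.perf_counter() - start

print(f"5) padded forward pass: {padded_time:.4f}s, packed forward pass: {packed_time:.4f}s")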