@dayyass
Last active January 12, 2023 12:43
RNN inference time with/without pack_padded_sequence comparison.
### INFERENCE TIME WITH/WITHOUT pack_padded_sequence (Jupyter notebook, Python 3.8.5) ###
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# hyper-parameters
BATCH_SIZE = 256
SEQ_LEN = 50
EMBEDDING_DIM = 300
RNN_HIDDEN_SIZE = 300

# init variable-length embeddings
lengths = torch.randint(low=1, high=SEQ_LEN + 1, size=(BATCH_SIZE,))
embeddings = torch.randn(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM)

# init recurrent layer
rnn = nn.RNN(input_size=EMBEDDING_DIM, hidden_size=RNN_HIDDEN_SIZE, batch_first=True)

# get recurrent layer output for embeddings
%timeit output, last_hidden = rnn(embeddings)
# 97.9 ms ± 3.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# pack embeddings
embeddings_packed = pack_padded_sequence(embeddings, lengths=lengths, batch_first=True, enforce_sorted=False)

# get recurrent layer output for packed embeddings
%timeit output_packed, last_hidden_packed = rnn(embeddings_packed)
# 52.4 ms ± 2.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# RNN processes packed embeddings about two times faster than padded embeddings:
# packing lets the RNN skip padded timesteps, so the work scales with the sum of the
# true lengths rather than with BATCH_SIZE * SEQ_LEN (here the average length is
# roughly SEQ_LEN / 2, hence the roughly 2x speedup).
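The %timeit measurements above rely on the IPython magic and cannot be run from a plain script. Below is a minimal sketch of the same comparison using the standard-library timeit module; the loop count, variable names, and millisecond formatting are illustrative choices rather than part of the original gist, and absolute numbers will depend on the machine.

import timeit

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM, RNN_HIDDEN_SIZE = 256, 50, 300, 300

lengths = torch.randint(low=1, high=SEQ_LEN + 1, size=(BATCH_SIZE,))
embeddings = torch.randn(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM)
rnn = nn.RNN(input_size=EMBEDDING_DIM, hidden_size=RNN_HIDDEN_SIZE, batch_first=True)

# pack once outside the timed region so only the forward pass is measured,
# mirroring the notebook above
embeddings_packed = pack_padded_sequence(
    embeddings, lengths=lengths, batch_first=True, enforce_sorted=False
)

n_loops = 10
t_padded = timeit.timeit(lambda: rnn(embeddings), number=n_loops)
t_packed = timeit.timeit(lambda: rnn(embeddings_packed), number=n_loops)
print(f"padded embeddings: {1000 * t_padded / n_loops:.1f} ms per forward pass")
print(f"packed embeddings: {1000 * t_packed / n_loops:.1f} ms per forward pass")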

### OUTPUT EQUIVALENCE WITH/WITHOUT pack_padded_sequence (second gist file) ###
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# hyper-parameters
BATCH_SIZE = 2
SEQ_LEN = 3
EMBEDDING_DIM = 5
RNN_HIDDEN_SIZE = 5
PADDING_VALUE = 0.0
# init two embeddings in batch: first full-length and second padded
full_length_idx, padded_idx = 0, 1
lengths = torch.tensor([SEQ_LEN, SEQ_LEN - 1])
embeddings = torch.randn(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM)
# init recurrent layer
rnn = nn.RNN(input_size=EMBEDDING_DIM, hidden_size=RNN_HIDDEN_SIZE, batch_first=True)
# get recurrent layer output for embeddings
output, last_hidden = rnn(embeddings)
last_hidden = last_hidden.transpose(0, 1)  # transpose to [batch_size, num_layers * num_directions, hidden_size] (batch first)
# pack embeddings
embeddings_packed = pack_padded_sequence(embeddings, lengths=lengths, batch_first=True)
# get recurrent layer output for packed embeddings
output_packed, last_hidden_packed = rnn(embeddings_packed)
last_hidden_packed = last_hidden_packed.transpose(0, 1)  # transpose to [batch_size, num_layers * num_directions, hidden_size] (batch first)
# unpack recurrent layer output
output_packed, _ = pad_packed_sequence(output_packed, batch_first=True, padding_value=PADDING_VALUE)
### COMPARISON AND RESULTS ###
# rnn outputs and last hidden state with/without pack_padded_sequence are equal for full-length batch element
comparison_1_1 = torch.allclose(  # True
    output[full_length_idx],
    output_packed[full_length_idx],
)
comparison_1_2 = torch.allclose(  # True
    last_hidden[full_length_idx],
    last_hidden_packed[full_length_idx],
)
print(f"1) full-length batch element outputs equivalence: {comparison_1_1}")
print(f"1) full-length batch element last hidden state equivalence: {comparison_1_2}")
# rnn outputs and last hidden state with/without pack_padded_sequence are different for padded batch element
comparison_2_1 = torch.allclose(  # False
    output[padded_idx],
    output_packed[padded_idx],
)
comparison_2_2 = torch.allclose(  # False
    last_hidden[padded_idx],
    last_hidden_packed[padded_idx],
)
print(f"2) padded batch element outputs equivalence: {comparison_2_1}")
print(f"2) padded batch element last hidden state equivalence: {comparison_2_2}")
# pack_padded_sequence fills the output at padded positions with the padding value
comparison_3 = torch.allclose(  # True
    output_packed[padded_idx][lengths[padded_idx]],
    torch.full(size=(RNN_HIDDEN_SIZE,), fill_value=PADDING_VALUE),
)
print(f"3) padded positions are filled with the padding value: {comparison_3}")
# with pack_padded_sequence, the last hidden state corresponds to the last timestep before padding
# (unlike the unpacked case, where the RNN keeps running over the padded positions)
comparison_4_1 = torch.allclose(  # False
    output[padded_idx][lengths[padded_idx] - 1],
    last_hidden[padded_idx],
)
comparison_4_2 = torch.allclose(  # True
    output_packed[padded_idx][lengths[padded_idx] - 1],
    last_hidden_packed[padded_idx],
)
print(f"4) padded batch element: last output before padding vs. last hidden state equivalence (no packing): {comparison_4_1}")
print(f"4) padded batch element: last output before padding vs. last hidden state equivalence (packed): {comparison_4_2}")
### CONCLUSION ###
# 1) rnn outputs and last hidden state with/without pack_padded_sequence are equal for the full-length batch element
# 2) rnn outputs and last hidden state with/without pack_padded_sequence differ for the padded batch element
# 3) pack_padded_sequence fills the output at padded positions with the padding value
# 4) with pack_padded_sequence, the last hidden state corresponds to the last timestep before padding (unlike the unpacked case)
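A practical consequence of conclusion 4: since the packed run's last hidden state corresponds to each sequence's final non-padded timestep, the same vectors can also be recovered from the unpacked output tensor by indexing with the lengths. A minimal sketch that continues the script above (output_packed, lengths, and last_hidden_packed are defined there; idx and last_valid_output are names introduced here for illustration):

# gather output_packed[b, lengths[b] - 1, :] for every batch element b
idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, output_packed.size(-1))  # [batch_size, 1, hidden_size]
last_valid_output = output_packed.gather(dim=1, index=idx).squeeze(1)  # [batch_size, hidden_size]
# equals the last hidden state returned by the packed run
print(torch.allclose(last_valid_output, last_hidden_packed.squeeze(1)))  # True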