How to use pad_packed_sequence in PyTorch
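The snippet below walks through the full pipeline: pad a small batch of variable-length character sequences, sort them by length, pack them with pack_padded_sequence, run them through an LSTM, and recover a padded output with pad_packed_sequence.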
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import numpy as np
import itertools

def flatten(l):
    return list(itertools.chain.from_iterable(l))
seqs = ['ghatmasala', 'nicela', 'chutpakodas']

# make <pad> idx 0
vocab = ['<pad>'] + sorted(list(set(flatten(seqs))))

# make model
embed = nn.Embedding(len(vocab), 10).cuda()
lstm = nn.LSTM(10, 5).cuda()

vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]

# get the length of each seq in your batch
seq_lengths = torch.cuda.LongTensor([len(seq) for seq in vectorized_seqs])

# dump padding everywhere, and place seqs on the left.
# NOTE: you only need a tensor as big as your longest sequence
seq_tensor = Variable(torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long().cuda()
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
# SORT YOUR TENSORS BY LENGTH!
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]

# utils.rnn lets you give (B, L, D) tensors, where B is the batch size and L is the max length,
# if you use batch_first=True. Otherwise, give (L, B, D) tensors.
seq_tensor = seq_tensor.transpose(0, 1)  # (B, L) -> (L, B)

# embed your sequences
seq_tensor = embed(seq_tensor)  # (L, B) -> (L, B, D)
# pack them up nicely
packed_input = pack_padded_sequence(seq_tensor, seq_lengths.cpu().numpy())

# throw them through your LSTM (remember to give batch_first=True here if you packed with it)
packed_output, (ht, ct) = lstm(packed_input)

# unpack your output if required; `output` is zero-padded back to (L, B, hidden)
output, _ = pad_packed_sequence(packed_output)
print(output)
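# If you need the outputs back in the original (unsorted) batch order, you can
# invert the sorting permutation, for example:
#   _, unperm_idx = perm_idx.sort(0)
#   output = output[:, unperm_idx]  # batch dimension is dim 1 here (time-major)
#   ht = ht[:, unperm_idx]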
# Or, if you just want the final hidden state:
print(ht[-1])
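The comments above mention batch_first=True. As a point of comparison, here is a minimal, self-contained sketch of the same pack -> LSTM -> unpack pipeline with batch-first tensors; the sizes, lengths, and random inputs are made up for illustration, and it skips the embedding step by feeding random vectors directly.

# Minimal batch-first sketch (illustrative sizes, random inputs, CPU).
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch, max_len, emb_dim, hidden = 3, 11, 10, 5
lengths = [11, 9, 6]                           # already sorted, longest first
padded = torch.randn(batch, max_len, emb_dim)  # (B, L, D) because batch_first=True

lstm_bf = nn.LSTM(emb_dim, hidden, batch_first=True)
packed = pack_padded_sequence(padded, lengths, batch_first=True)
packed_out, (h_n, c_n) = lstm_bf(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

print(out.shape)      # torch.Size([3, 11, 5])  -- (B, L, hidden), zero-padded
print(h_n[-1].shape)  # torch.Size([3, 5])      -- final hidden state per sequence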