@jjfiv
Last active June 19, 2018 18:36
Pack variable-length sequences within a batch for a PyTorch LSTM, returning each sequence's final hidden state in the original batch order.
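import numpy as np
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

def pack_lstm(items, lstm):
    """Encode N variable-length sequences with `lstm`.

    items: list of N tensors, each shaped (seq_len_i, input_size).
    Returns the final hidden state of every layer/direction, flattened to
    (N, num_layers * num_directions * hidden_size), in the original order.
    """
    N = len(items)
    # pack_padded_sequence expects the batch sorted longest-first.
    reorder_args = np.argsort([len(it) for it in items])[::-1]
    # Inverse permutation: position of each original item in the sorted batch.
    origin_args = torch.from_numpy(np.argsort(reorder_args))
    ordered = [items[i] for i in reorder_args]
    packed_items = pack_padded_sequence(pad_sequence(ordered, batch_first=True),
                                        [len(od) for od in ordered], batch_first=True)
    _, (hn, _) = lstm(packed_items)
    # hn: (layers * dirs, N, hidden) -> (N, layers * dirs * hidden)
    by_inst_repr = hn.transpose(0, 1).reshape(N, -1)
    # Undo the sort so row i belongs to items[i]; the index must be on hn's device.
    return torch.index_select(by_inst_repr, 0, origin_args.to(by_inst_repr.device))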
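On PyTorch 1.1 and later, `pack_sequence` and `pack_padded_sequence` accept `enforce_sorted=False`, which makes PyTorch record the sort permutation inside the `PackedSequence`; `nn.LSTM` then returns `h_n` already in the original batch order, so the manual argsort and "untwizzle" step are not needed. The following is a sketch of that variant, not part of the original gist; the name `pack_lstm_unsorted` and the LSTM sizes in the usage lines are illustrative assumptions.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_sequence


def pack_lstm_unsorted(items, lstm):
    """Hypothetical variant of pack_lstm relying on enforce_sorted=False.

    items: list of N tensors, each shaped (seq_len_i, input_size).
    Returns (N, num_layers * num_directions * hidden_size) in the original order.
    """
    n = len(items)
    # PyTorch sorts internally and stores the permutation in the PackedSequence.
    packed = pack_sequence(items, enforce_sorted=False)
    _, (h_n, _) = lstm(packed)
    # h_n is already un-sorted: (layers * dirs, N, hidden) -> (N, layers * dirs * hidden)
    return h_n.transpose(0, 1).reshape(n, -1)


# Usage: three sequences of different lengths, deliberately not sorted.
torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=4, num_layers=2, bidirectional=True)
items = [torch.randn(length, 3) for length in (2, 5, 3)]
reprs = pack_lstm_unsorted(items, lstm)
print(reprs.shape)  # torch.Size([3, 16])
```

Each row should match running the LSTM on that sequence alone, which is a quick way to check that either version restores the original order.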