@wanchaol
Last active August 29, 2018 21:10
pack_padded tracing: a minimal repro showing that tracing a function through pack_padded_sequence leaves a Python operator in the graph, so the resulting ScriptModule cannot be saved.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import itertools
import tempfile

def pack_pad_seq(seq_tensor, seq_lengths):
    # sort by length, descending, as pack_padded_sequence requires
    seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
    seq_tensor = seq_tensor[perm_idx]
    seq_tensor = seq_tensor.transpose(0, 1)  # (B, L) -> (L, B)
    seq_tensor = embed(seq_tensor)           # (L, B) -> (L, B, 10); `embed` and `lstm` are globals set in __main__
    packed_input = pack_padded_sequence(seq_tensor, seq_lengths)
    packed_output, (ht, ct) = lstm(packed_input)
    return packed_output

def flatten(l):
    return list(itertools.chain.from_iterable(l))

if __name__ == "__main__":
    seqs = ['gigantic_string', 'tiny_str', 'medium_str']
    # character-level vocabulary with a padding token at index 0
    vocab = ['<pad>'] + sorted(list(set(flatten(seqs))))
    embed = nn.Embedding(len(vocab), 10)
    lstm = nn.LSTM(10, 5)

    vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
    seq_lengths = torch.LongTensor([len(v) for v in vectorized_seqs])

    # right-pad every sequence with zeros up to the longest length
    seq_tensor = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long()
    for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
        seq_tensor[idx, :seqlen] = torch.LongTensor(seq)

    # eager-mode warm-up call
    pack_pad_seq(seq_tensor, seq_lengths)
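    # --- Added sanity check, not in the original gist: unpack the packed
    # output in eager mode; the expected shapes assume the inputs above
    # (max_len=15, batch=3, hidden=5).
    output, out_lengths = pad_packed_sequence(pack_pad_seq(seq_tensor, seq_lengths))
    print(output.shape)   # torch.Size([15, 3, 5])
    print(out_lengths)    # tensor([15, 10, 8]) after the descending sort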
    # old-style (PyTorch 0.4-era) tracing API: trace takes the example
    # inputs and returns a decorator for the function
    fn_traced = torch.jit.trace(seq_tensor, seq_lengths)(pack_pad_seq)
    print(fn_traced.graph)

    # build a ScriptModule from the traced graph and try to serialize it
    m = torch.jit.ScriptModule()
    m._create_method_from_graph("forward", fn_traced.graph)
    f = tempfile.NamedTemporaryFile(delete=True)
    m.save(f.name)  # raises: the graph contains a non-exportable Python operator
==============
Traceback (most recent call last):
  File "test_padd.py", line 66, in <module>
    m.save(f.name)
RuntimeError: Couldn't export Python operator pack_padded_sequence_trace_wrapper

Defined at:
/data/users/wanchaol/pytorch/torch/nn/utils/rnn.py(172): _symbolic_pack_padded_sequence
/data/users/wanchaol/pytorch/torch/onnx/__init__.py(102): wrapper
test_padd.py(24): pack_pad_seq
/data/users/wanchaol/pytorch/torch/jit/__init__.py(290): wrapper
test_padd.py(60): <module>
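
Reading the traceback (not an authoritative diagnosis): during tracing, pack_padded_sequence is routed through an ONNX symbolic wrapper (_symbolic_pack_padded_sequence, via torch/onnx/__init__.py), which records it in the traced graph as a Python operator named pack_padded_sequence_trace_wrapper; ScriptModule.save can only serialize built-in operators, so the save step raises. One possible workaround, sketched below against the same globals (embed, lstm) and inputs as the script above, is to trace only the tensor-only preamble and keep pack_padded_sequence and the LSTM in eager mode outside the saved graph. embed_sorted is a name introduced here for illustration, not part of the original gist.

# hypothetical workaround sketch: trace only the ATen-only preamble
def embed_sorted(seq_tensor, seq_lengths):
    sorted_lengths, perm_idx = seq_lengths.sort(0, descending=True)
    emb = embed(seq_tensor[perm_idx].transpose(0, 1))  # (L, B, 10)
    return emb, sorted_lengths

embed_traced = torch.jit.trace(seq_tensor, seq_lengths)(embed_sorted)

# pack + LSTM stay in eager mode, outside the traced region
emb, sorted_lengths = embed_traced(seq_tensor, seq_lengths)
packed_output, (ht, ct) = lstm(pack_padded_sequence(emb, sorted_lengths))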