@jamesr66a
Created March 3, 2018 01:05
import torch
import torch.onnx

class MyCastModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.int()  # cast the int64 input to int32
        return torch.add(x, x)

mcm = MyCastModule()
mcm.eval()

x = torch.zeros(1, 2, 3).long()  # int64 input of shape (1, 2, 3)

torch.onnx._export(mcm, x, 'test.onnx', verbose=True)  # fails while tracing, traceback below
################
$ python casttest.py
Traceback (most recent call last):
  File "casttest.py", line 17, in <module>
    torch.onnx._export(mcm, x, 'test.onnx', verbose=True)
  File "/Users/jamesreed/onnx-fairseq/pytorch/torch/onnx/__init__.py", line 132, in _export
    trace, torch_out = torch.jit.get_trace_graph(model, args)
  File "/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py", line 252, in get_trace_graph
    return LegacyTracedModule(f, nderivs=nderivs)(*args, **kwargs)
  File "/Users/jamesreed/onnx-fairseq/pytorch/torch/nn/modules/module.py", line 363, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py", line 288, in forward
    torch._C._tracer_exit(out_vars)
RuntimeError: /Users/jamesreed/onnx-fairseq/pytorch/torch/csrc/jit/tracer.h:117: getTracingState: Assertion `state` failed.