Skip to content

Instantly share code, notes, and snippets.

@wanchaol
Created September 6, 2018 18:27
Show Gist options
  • Save wanchaol/4df5720b33e2a91e7eeea05ccbefdc0d to your computer and use it in GitHub Desktop.
Save wanchaol/4df5720b33e2a91e7eeea05ccbefdc0d to your computer and use it in GitHub Desktop.
import torch
class Test(torch.nn.Module):
    """Repro module that slices its input inside a Python loop.

    Every iteration overwrites the result, so ``forward`` returns the slice
    taken on the last iteration, ``input[0:9]``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        sliced = 0  # placeholder; always replaced by the loop below
        for stop in range(10):
            sliced = input[0:stop]
        return sliced
# Tracing path: torch.jit.trace runs Test once on the example input and records
# the ops that executed. The Python loop in forward does not appear in the
# recorded graph; only the slice from the final iteration survives.
test_traced = torch.jit.trace(Test(), torch.rand(12,4))
print(type(test_traced))
print(test_traced.graph)
# Run the eager module once; the result is passed to the exporter as
# example_outputs.
# NOTE(review): this uses a fresh random input rather than the one exported
# below; presumably only the output shape matters here — confirm.
example_outputs = Test()(torch.rand(12, 4))
print('tracing translation export')
import io
# Export into an in-memory buffer, so nothing is written to disk.
f = io.BytesIO()
# torch.onnx._export is underscore-prefixed (non-public).
# NOTE(review): the example_outputs keyword belongs to the exporter of this
# PyTorch era and may not be accepted by newer releases.
torch.onnx._export(Test(), (torch.rand(12, 4),), f, verbose=True, example_outputs=example_outputs)
print('-----------------')
# Scripted counterpart of Test.forward. TorchScript compiles the loop into a
# prim::Loop instead of unrolling it, so the slice end inside the loop is the
# loop counter `i`, a runtime value rather than a constant.
# _to_tensor is not defined in this file; it is resolved by the TorchScript
# compiler (the printed graph lowers it to prim::NumToTensor).
# NOTE(review): _to_tensor appears to be a TorchScript builtin of this era and
# may not exist in newer releases.
@torch.jit.script
def foo(x):
    # Presumably makes y a tensor before the loop, matching the type assigned
    # inside it — confirm.
    y = _to_tensor(0)
    for i in range(10):
        # End index `i` is the loop counter (non-constant in the graph).
        y = x[0:i]
    return y
print(foo.graph)
# Eager call of the scripted function; its result is passed to the exporter
# as example_outputs.
ex_outputs = foo(torch.rand(12, 4))
print('scripting translation export')
f2 = io.BytesIO()
# In the recorded run this call fails. While converting the loop body, the
# ONNX symbolic for aten::slice parses its arguments as constants, but the end
# index is the loop counter, so it raises
# "ONNX symbolic expected a constant value in the trace".
torch.onnx._export(foo, (torch.rand(12, 4),), f2, verbose=True, example_outputs=ex_outputs)
================= Output of the script above =================
<class 'torch.jit.TopLevelTracedModule'>
graph(%0 : Float(12, 4)) {
%46 : int = prim::Constant[value=0](), scope: Test
%47 : int = prim::Constant[value=0](), scope: Test
%48 : int = prim::Constant[value=9](), scope: Test
%49 : int = prim::Constant[value=1](), scope: Test
%50 : Float(9, 4) = aten::slice(%0, %46, %47, %48, %49), scope: Test
return (%50);
}
tracing translation export
graph(%0 : Float(12, 4)) {
%1 : Float(9, 4) = onnx::Slice[axes=[0], ends=[9], starts=[0]](%0), scope: Test
return (%1);
}
-----------------
graph(%x : Dynamic) {
%3 : int = prim::Constant[value=10]()
%1 : int = prim::Constant[value=0]()
%y.1 : Long() = prim::NumToTensor(%1)
%4 : int = prim::Constant[value=1]()
%y : Long() = prim::Loop(%3, %4, %y.1)
block0(%i : int, %10 : Long()) {
%7 : int = prim::Constant[value=0]()
%8 : int = prim::Constant[value=1]()
%y.2 : Dynamic = aten::slice(%x, %7, %1, %i, %8)
%11 : int = prim::Constant[value=1]()
-> (%11, %y.2)
}
return (%y);
}
scripting translation export
Traceback (most recent call last):
File "/home/wanchaol/test_slice.py", line 41, in <module>
torch.onnx._export(foo, (torch.rand(12, 4),), f2, verbose=True, example_outputs=ex_outputs)
File "/data/users/wanchaol/pytorch/torch/onnx/__init__.py", line 22, in _export
return utils._export(*args, **kwargs)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 280, in _export
example_outputs, propagate)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 226, in _model_to_graph
graph = _optimize_graph(graph, operator_export_type)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 154, in _optimize_graph
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
File "/data/users/wanchaol/pytorch/torch/onnx/__init__.py", line 52, in _run_symbolic_function
return utils._run_symbolic_function(*args, **kwargs)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 533, in _run_symbolic_function
torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
File "/data/users/wanchaol/pytorch/torch/onnx/__init__.py", line 52, in _run_symbolic_function
return utils._run_symbolic_function(*args, **kwargs)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 503, in _run_symbolic_function
return fn(g, *inputs, **attrs)
File "/data/users/wanchaol/pytorch/torch/onnx/symbolic.py", line 87, in wrapper
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
File "/data/users/wanchaol/pytorch/torch/onnx/symbolic.py", line 87, in <listcomp>
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
File "/data/users/wanchaol/pytorch/torch/onnx/symbolic.py", line 44, in _parse_arg
raise RuntimeError("ONNX symbolic expected a constant value in the trace")
RuntimeError: ONNX symbolic expected a constant value in the trace
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment