Created
September 6, 2018 18:27
-
-
Save wanchaol/4df5720b33e2a91e7eeea05ccbefdc0d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch


class Test(torch.nn.Module):
    """Minimal module whose forward slices its input inside a Python loop.

    Used to contrast tracing with scripting. Under ``torch.jit.trace`` the
    loop is unrolled, so only the final slice ``input[0:9]`` is recorded,
    with constant bounds.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input[0:9]``, the slice taken on the last loop iteration.

        Args:
            input: tensor sliced along its first dimension. If it has fewer
                than 9 rows, all rows are returned.
        """
        # Mirrors the scripted ``foo`` below: ``y`` is overwritten on every
        # iteration, so only the last slice (i == 9) survives.
        y = 0
        for i in range(10):
            y = input[0:i]
        return y
# Tracing records only the operations executed for the example input: the
# Python loop in Test.forward is unrolled, leaving a single aten::slice whose
# bounds are all constants.
test_traced = torch.jit.trace(Test(), torch.rand(12, 4))
print(type(test_traced))
print(test_traced.graph)

# Eager-mode result, passed to the exporter as ``example_outputs``.
example_outputs = Test()(torch.rand(12, 4))
print('tracing translation export')
import io

f = io.BytesIO()  # in-memory sink; the exported model is not written to disk
torch.onnx._export(Test(), (torch.rand(12, 4),), f, verbose=True,
                   example_outputs=example_outputs)
print('-----------------')
@torch.jit.script
def foo(x):
    """Scripted counterpart of ``Test.forward``: return ``x[0:9]``.

    Unlike tracing, scripting keeps the loop in the graph, so the slice end
    is the loop counter rather than a constant.
    """
    # ``y`` is loop-carried, so TorchScript needs it defined as a Tensor
    # before the loop. The value is never observed: range(10) always
    # overwrites it.
    y = torch.tensor(0)
    for i in range(10):
        y = x[0:i]
    return y
print(foo.graph)
# Eager result of the scripted function, passed as ``example_outputs``.
ex_outputs = foo(torch.rand(12, 4))
print('scripting translation export')
f2 = io.BytesIO()
# In the recorded run this call raises "RuntimeError: ONNX symbolic expected
# a constant value in the trace". Presumably this is because the scripted
# aten::slice takes its end from the loop counter %i instead of a constant.
torch.onnx._export(foo, (torch.rand(12, 4),), f2, verbose=True,
                   example_outputs=ex_outputs)
=================
<class 'torch.jit.TopLevelTracedModule'>
graph(%0 : Float(12, 4)) {
%46 : int = prim::Constant[value=0](), scope: Test
%47 : int = prim::Constant[value=0](), scope: Test
%48 : int = prim::Constant[value=9](), scope: Test
%49 : int = prim::Constant[value=1](), scope: Test
%50 : Float(9, 4) = aten::slice(%0, %46, %47, %48, %49), scope: Test
return (%50);
}
tracing translation export
graph(%0 : Float(12, 4)) {
%1 : Float(9, 4) = onnx::Slice[axes=[0], ends=[9], starts=[0]](%0), scope: Test
return (%1);
}
-----------------
graph(%x : Dynamic) {
%3 : int = prim::Constant[value=10]()
%1 : int = prim::Constant[value=0]()
%y.1 : Long() = prim::NumToTensor(%1)
%4 : int = prim::Constant[value=1]()
%y : Long() = prim::Loop(%3, %4, %y.1)
block0(%i : int, %10 : Long()) {
%7 : int = prim::Constant[value=0]()
%8 : int = prim::Constant[value=1]()
%y.2 : Dynamic = aten::slice(%x, %7, %1, %i, %8)
%11 : int = prim::Constant[value=1]()
-> (%11, %y.2)
}
return (%y);
}
scripting translation export
Traceback (most recent call last):
File "/home/wanchaol/test_slice.py", line 41, in <module>
torch.onnx._export(foo, (torch.rand(12, 4),), f2, verbose=True, example_outputs=ex_outputs)
File "/data/users/wanchaol/pytorch/torch/onnx/__init__.py", line 22, in _export
return utils._export(*args, **kwargs)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 280, in _export
example_outputs, propagate)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 226, in _model_to_graph
graph = _optimize_graph(graph, operator_export_type)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 154, in _optimize_graph
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
File "/data/users/wanchaol/pytorch/torch/onnx/__init__.py", line 52, in _run_symbolic_function
return utils._run_symbolic_function(*args, **kwargs)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 533, in _run_symbolic_function
torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
File "/data/users/wanchaol/pytorch/torch/onnx/__init__.py", line 52, in _run_symbolic_function
return utils._run_symbolic_function(*args, **kwargs)
File "/data/users/wanchaol/pytorch/torch/onnx/utils.py", line 503, in _run_symbolic_function
return fn(g, *inputs, **attrs)
File "/data/users/wanchaol/pytorch/torch/onnx/symbolic.py", line 87, in wrapper
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
File "/data/users/wanchaol/pytorch/torch/onnx/symbolic.py", line 87, in <listcomp>
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
File "/data/users/wanchaol/pytorch/torch/onnx/symbolic.py", line 44, in _parse_arg
raise RuntimeError("ONNX symbolic expected a constant value in the trace")
RuntimeError: ONNX symbolic expected a constant value in the trace
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment