Created June 27, 2019 14:04
-
-
Save iiSeymour/9c306e53adc0d0c3c92001263caffe5d to your computer and use it in GitHub Desktop.
Symbolic patching for exporting `torch.{min,max}` to ONNX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.onnx.symbolic | |
def sym_patch(func, op1, op2):
    """Wrap an ONNX symbolic so an ``onnx::ATen`` result is re-emitted natively.

    The returned function first calls ``func`` with the same arguments.

    If that result is a non-empty tuple whose first element's node kind is
    ``"onnx::ATen"``, the ATen result is discarded. Two native ops are
    emitted in its place:

    * ``op1`` (e.g. ``"ReduceMax"``) for the reduced values;
    * ``op2`` (e.g. ``"ArgMax"``) for the indices.

    Every other result is returned unchanged.

    Args:
        func: the original symbolic function being wrapped; it is called as
            ``func(g, node, dim_or_y, keepdim)``.
        op1: name of the ONNX op that produces the reduced values.
        op2: name of the ONNX op that produces the indices.

    Returns:
        A symbolic function with signature
        ``(g, node, dim_or_y=None, keepdim=None)``.
    """
    def replace(g, node, dim_or_y=None, keepdim=None):
        ops = func(g, node, dim_or_y, keepdim)
        # Only a non-empty tuple whose first output comes from an ATen node is
        # rewritten. Anything else (including an empty tuple, which used to
        # raise IndexError here) passes through untouched.
        if not isinstance(ops, tuple) or not ops or ops[0].node().kind() != "onnx::ATen":
            return ops
        # NOTE(review): the ATen node already emitted by ``func`` stays in the
        # graph without users; presumably the exporter prunes it — confirm.
        # ``_get_const`` is a private helper of torch.onnx.symbolic and may be
        # missing or renamed in other PyTorch versions.
        dim = torch.onnx.symbolic._get_const(dim_or_y, 'i', 'dim')
        keepdim = torch.onnx.symbolic._get_const(keepdim, 'i', 'keepdim')
        # The Reduce* op gets a list of axes; the Arg* op gets a single axis.
        reduced = g.op(op1, node, axes_i=[dim], keepdims_i=keepdim)
        indices = g.op(op2, node, axis_i=dim, keepdims_i=keepdim)
        return reduced, indices
    return replace
# Replace the max/min symbolics with wrapped versions. Each wrapped version
# turns an onnx::ATen result into a Reduce*/Arg* pair. Max is patched first,
# then min.
# NOTE(review): assumes the exporter looks these functions up by attribute on
# torch.onnx.symbolic at export time — confirm for the targeted PyTorch version.
for _name, _reduce_op, _index_op in (
    ("max", "ReduceMax", "ArgMax"),
    ("min", "ReduceMin", "ArgMin"),
):
    setattr(
        torch.onnx.symbolic,
        _name,
        sym_patch(getattr(torch.onnx.symbolic, _name), _reduce_op, _index_op),
    )
del _name, _reduce_op, _index_op
class MaxModel(torch.nn.Module):
    """Toy module returning ``torch.max`` and ``torch.min`` of its input along dim 1."""

    def forward(self, x):
        # A tuple display evaluates left to right, so max still runs before min.
        return torch.max(x, dim=1), torch.min(x, dim=1)
# Export the toy model so the torch.max/torch.min calls with dim=1 go through
# the exporter. verbose=True requests the exporter's debug printout of the
# graph. Building the model first does not touch the RNG, so the random input
# is unchanged.
model = MaxModel()
x = torch.randn(4, 4)
torch.onnx.export(model, x, "model.onnx", verbose=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
ONNX graph -