Created
November 25, 2020 16:49
-
-
Save dayyass/6d8f9f85f22a7d8e4179e18f624a652f to your computer and use it in GitHub Desktop.
ONNX doesn't support PyTorch Adaptive Pooling (nor Global Pooling, its special case with output_size=1). Below is an implementation of Global Pooling that is compatible with ONNX.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import torch.nn as nn | |
import onnx | |
import onnxruntime | |
##### INIT 1d, 2d, 3d GLOBAL POOLING MODULES ##### | |
class GlobalAvgPool1d(nn.Module):
    """
    Global average pooling over the last dimension.

    Collapses the last axis to size 1 with a plain ``mean`` reduction, so an
    (N, C, L) input becomes (N, C, 1) -- the same values as
    ``nn.AdaptiveAvgPool1d(1)``, but built from an op that exports to ONNX.
    Inputs of any rank are accepted; only the last axis is reduced.
    """

    # No __init__ override: the module holds no parameters or state, so the
    # inherited nn.Module constructor is sufficient.

    def forward(self, x):
        pooled = x.mean(-1, keepdim=True)
        return pooled
class GlobalMaxPool1d(nn.Module):
    """
    Global max pooling over the last dimension.

    Collapses the last axis to size 1 by taking its maximum, so an (N, C, L)
    input becomes (N, C, 1) -- the same values as ``nn.AdaptiveMaxPool1d(1)``,
    but built from an op that exports to ONNX. Only the last axis is reduced.
    """

    def forward(self, x):
        # Tensor.max(dim=...) yields a (values, indices) pair; indices are unused.
        values, _ = x.max(dim=-1, keepdim=True)
        return values
class GlobalAvgPool2d(nn.Module):
    """
    Reduce mean over the last two dimensions (global average pooling).

    For an (N, C, H, W) input this returns an (N, C, 1, 1) tensor with the
    same values as ``nn.AdaptiveAvgPool2d(1)``, using a plain ``mean`` so the
    module exports to ONNX. The mean is computed as one reduction over both
    axes instead of two chained reductions: one op (and one ReduceMean node
    in the exported graph) and no intermediate (..., H, 1) tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mean(dim=(-2, -1), keepdim=True)
class GlobalMaxPool2d(nn.Module):
    """
    Global max pooling over the last two dimensions.

    For an (N, C, H, W) input this returns an (N, C, 1, 1) tensor with the
    same values as ``nn.AdaptiveMaxPool2d(1)``, using plain ``max``
    reductions so the module exports to ONNX.
    """

    def forward(self, x):
        # Reduce the innermost axis first, then the next one; keepdim=True
        # preserves rank so the result ends in (..., 1, 1).
        for axis in (-1, -2):
            x = x.max(dim=axis, keepdim=True).values
        return x
class GlobalAvgPool3d(nn.Module):
    """
    Reduce mean over the last three dimensions (global average pooling).

    For an (N, C, D, H, W) input this returns an (N, C, 1, 1, 1) tensor with
    the same values as ``nn.AdaptiveAvgPool3d(1)``, using a plain ``mean`` so
    the module exports to ONNX. The mean is computed as one reduction over
    all three axes instead of three chained reductions: one op (and one
    ReduceMean node in the exported graph) and no intermediate tensors.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mean(dim=(-3, -2, -1), keepdim=True)
class GlobalMaxPool3d(nn.Module):
    """
    Global max pooling over the last three dimensions.

    For an (N, C, D, H, W) input this returns an (N, C, 1, 1, 1) tensor with
    the same values as ``nn.AdaptiveMaxPool3d(1)``, using plain ``max``
    reductions so the module exports to ONNX.
    """

    def forward(self, x):
        # Reduce from the innermost axis outwards; keepdim=True preserves rank
        # so the result ends in (..., 1, 1, 1).
        for axis in (-1, -2, -3):
            x = x.max(dim=axis, keepdim=True).values
        return x
##### EXAMPLE OF ONNX EXPORT #####
# NOTE(review): everything below runs at import time and writes ONNX_PATH into
# the current working directory; consider moving it under an
# `if __name__ == "__main__":` guard if the pooling modules are imported elsewhere.
ONNX_PATH = "global_pooling.onnx"  # single place that names the exported file

global_pooling = GlobalMaxPool2d()  # init global pooling layer

# input to the global pooling layer
tensor = torch.randn(1, 1, 224, 224)  # init first two dimensions with ones to allow dynamic axes
torch_out = global_pooling(tensor)  # torch inference

# export the global pooling layer
torch.onnx.export(
    model=global_pooling,  # model being run
    args=tensor,  # model input (or a tuple for multiple inputs)
    f=ONNX_PATH,  # where to save the model (can be a file or file-like object)
    export_params=True,  # store parameters in the model file (GlobalMaxPool2d has none)
    opset_version=10,  # ONNX operator-set version to target
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],  # the model's input names
    output_names=['output'],  # the model's output names
    dynamic_axes={  # axes whose size may change at inference time
        'input': {0: 'batch_size', 1: 'n_channel', 2: 'height', 3: 'width'},
        # the last two output axes are always 1 after global pooling, so they stay static
        'output': {0: 'batch_size', 1: 'n_channel'},
    },
)

onnx_model = onnx.load(ONNX_PATH)  # load onnx global pooling layer
onnx.checker.check_model(onnx_model)  # verify the model's structure and confirm that the model has a valid schema

# Create an ONNX Runtime inference session. `providers` is passed explicitly:
# since onnxruntime 1.9, builds that enable more than one execution provider
# (e.g. GPU builds) raise ValueError when it is omitted. The CPU provider is
# always available.
ort_session = onnxruntime.InferenceSession(ONNX_PATH, providers=['CPUExecutionProvider'])

# compute ONNX Runtime output prediction
ort_inputs = {'input': tensor.numpy()}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
print('Exported model has been tested with ONNXRuntime, and the result looks good!')

# run inference on a tensor of a different shape to exercise the dynamic axes
tensor = torch.randn(2, 3, 128, 256)
torch_out = global_pooling(tensor)  # torch inference
ort_inputs = {'input': tensor.numpy()}  # onnx input
ort_outs = ort_session.run(None, ort_inputs)  # onnx inference
np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)  # compare torch with onnx
print('ONNX can work with arbitrary dimension tensor!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment