Created
January 10, 2022 14:45
-
-
Save romain-keramitas-prl/bf43254d15055f104e4c63abd504e476 to your computer and use it in GitHub Desktop.
Code to reproduce an ONNX Runtime (ORT) Gather node discrepancy between execution providers.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from onnxruntime import InferenceSession
import torch
import torch.nn as nn

# Repro: feed an out-of-range index to a single-Gather ONNX model and compare
# how ONNX Runtime execution providers react (raise an error vs. silently
# return a result). The model is one nn.Embedding, exported as one Gather.
NUM_EMBEDDINGS = 10
EMBEDDING_DIM = 5
BATCH_SIZE = 3
# Valid indices for the embedding table are 0 .. NUM_EMBEDDINGS - 1, so this
# is the first out-of-range index; tying it to NUM_EMBEDDINGS keeps the repro
# valid if the table size is changed.
OUT_OF_RANGE_INDEX = NUM_EMBEDDINGS
MODEL_PATH = "model.onnx"

# Create simple model with one Gather node
model = nn.Embedding(num_embeddings=NUM_EMBEDDINGS, embedding_dim=EMBEDDING_DIM)
# Tracing input only: it must be in range so the export itself succeeds.
x = torch.randint(0, NUM_EMBEDDINGS, (BATCH_SIZE,))
torch.onnx.export(
    model,
    args=(x,),
    f=MODEL_PATH,
    input_names=["input_ids"],
    output_names=["output"],
    opset_version=14,
    dynamic_axes={"input_ids": {0: "batch_size"}, "output": {0: "batch_size"}},
)

# Every element is out of range, so any correct Gather implementation should
# reject this input.
bad_input_ids = np.full(BATCH_SIZE, OUT_OF_RANGE_INDEX, dtype=np.int64)

for provider in ["CPUExecutionProvider", "CUDAExecutionProvider"]:
    print(f"testing {provider} ...")
    session = InferenceSession(MODEL_PATH, providers=[provider])
    # If the requested provider is unavailable (e.g. a CPU-only onnxruntime
    # build), ORT may fall back to another provider with only a warning; the
    # result would then be misattributed, so report and skip instead.
    if provider not in session.get_providers():
        print(f"{provider} is not available in this onnxruntime build, skipping")
        print()
        continue
    io_binding = session.io_binding()
    io_binding.bind_cpu_input("input_ids", bad_input_ids)
    io_binding.bind_output("output")
    try:
        session.run_with_iobinding(io_binding)
    except Exception as e:  # any provider error is the observation we want to report
        print(f"{provider} raised the following error:\n{e}")
    else:
        print(f"{provider} did not raise an error")
        # Show what the provider produced for the invalid indices.
        print(f"output:\n{io_binding.copy_outputs_to_cpu()[0]}")
    print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment