Last active
March 6, 2024 03:48
-
-
Save guschmue/e35157f1f13f51585c60da47400b359e to your computer and use it in GitHub Desktop.
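Diff ONNX models node by node: the first script exposes every intermediate tensor of a model as a graph output, the second runs two models and compares one named tensor from each.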
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
from __future__ import print_function | |
from __future__ import unicode_literals | |
import argparse | |
import logging | |
import traceback | |
import numpy as np | |
import onnx | |
from onnx import ModelProto, helper, onnx_pb, numpy_helper | |
# Named logger used by this script for its diagnostics.
logger = logging.getLogger("onnx-experiments")
# Configure root logging at INFO level so the script's messages are shown.
logging.basicConfig(level=logging.INFO)
def get_args():
    """Read command-line options.

    ``--input`` (required) is the ONNX model file to load; ``--output``
    (optional) is the path the rewritten model is saved to.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input", required=True, help="onnx input model file")
    cli.add_argument("--output", help="output model file")
    return cli.parse_args()
def main():
    """Expose every intermediate tensor of an ONNX model as a graph output.

    Reads the model given by ``--input`` and runs ONNX shape inference to get
    type/shape information (``graph.value_info``) for intermediate tensors.
    Every node output that has such information is then appended to
    ``graph.output``. Existing graph outputs are kept, in their original
    order, ahead of the new ones.

    With ``--output`` the rewritten model is serialized to that path. The
    names of all graph outputs are printed to stdout, one per line.

    Node outputs that shape inference could not describe are reported through
    ``logger`` and left out, because no ValueInfoProto is available for them.
    """
    args = get_args()
    model_proto = ModelProto()
    with open(args.input, "rb") as f:
        model_proto.ParseFromString(f.read())
    model_proto = onnx.shape_inference.infer_shapes(model_proto)
    g = model_proto.graph

    # Inferred type/shape of intermediate tensors, keyed by tensor name.
    tensor_values = {t.name: t for t in g.value_info}

    # Names already exposed (original graph outputs) or already handled.
    seen = {o.name for o in g.output}
    extra_outputs = []
    for n in g.node:
        for name in n.output:
            # An empty name marks an omitted optional output: there is no
            # tensor behind it, so it is neither exposed nor reported.
            if not name or name in seen:
                continue
            seen.add(name)
            if name in tensor_values:
                extra_outputs.append(tensor_values[name])
            else:
                # Logged rather than printed so stdout stays a clean list of
                # output names.
                logger.warning("%s not in shape inference", name)

    # Append (copies of) the new ValueInfoProtos after the existing outputs,
    # instead of emptying g.output and re-adding everything.
    g.output.extend(extra_outputs)

    if args.output:
        with open(args.output, "wb") as f:
            f.write(model_proto.SerializeToString())
    for o in g.output:
        print(o.name)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import sys | |
import numpy as np | |
import onnxruntime as onnxrt | |
# Input type strings (``input_meta.type`` in make_feed) mapped to the NumPy
# dtype used to generate random data for inputs of that type.
float_dict = {
    'tensor(float16)': 'float16',
    'tensor(float)': 'float32',
    'tensor(double)': 'float64',
}
# Integer input types; uint32 was missing although its narrower and wider
# siblings are listed, so uint32 inputs were rejected as unsupported.
integer_dict = {
    'tensor(int32)': 'int32',
    'tensor(int8)': 'int8',
    'tensor(uint8)': 'uint8',
    'tensor(int16)': 'int16',
    'tensor(uint16)': 'uint16',
    'tensor(uint32)': 'uint32',
    'tensor(int64)': 'int64',
    'tensor(uint64)': 'uint64',
}
def get_args():
    """Parse the command line.

    ``--input`` takes two comma-separated ONNX model files. ``--node`` takes
    the comma-separated names of the tensors to fetch: the first name from
    the first model, the second from the second. A single ``--node`` name is
    fetched from both models.

    Returns:
        argparse.Namespace whose ``input`` and ``node`` are both lists of
        exactly two strings.

    Raises:
        SystemExit: via ``parser.error`` when ``--node`` is missing or either
            option does not describe exactly two entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True,
                        help="two comma-separated onnx model files")
    parser.add_argument("--node",
                        help="comma-separated tensor names to pull, one per "
                             "model (a single name is used for both)")
    args = parser.parse_args()
    args.input = args.input.split(",")
    if len(args.input) != 2:
        parser.error("--input needs exactly two comma-separated model files")
    # main() indexes node[0] and node[1]; fail here with a usage message
    # rather than a TypeError/IndexError after the models have been loaded.
    if not args.node:
        parser.error("--node is required")
    args.node = args.node.split(",")
    if len(args.node) == 1:
        args.node = args.node * 2
    if len(args.node) != 2:
        parser.error("--node takes one or two comma-separated names")
    return args
def make_feed(sess):
    """Build a reproducible random input feed for an inference session.

    NumPy's global RNG is reseeded with 1 on every call, so sessions whose
    inputs have the same types, shapes and order receive identical data.

    Args:
        sess: object exposing ``get_inputs()``, whose items carry ``name``,
            ``type`` (e.g. ``'tensor(float)'``) and ``shape``. It is
            presumably an ``onnxruntime.InferenceSession``.

    Returns:
        dict mapping each input name to a NumPy array.

    Calls ``sys.exit(-1)`` when an input type is not listed in ``float_dict``
    or ``integer_dict`` and is not ``tensor(bool)``.
    """
    np.random.seed(1)
    feeds = {}
    for input_meta in sess.get_inputs():
        # Any dimension that is not a positive integer (a symbolic name given
        # as a str, None, 0, or a negative placeholder) is replaced with 1.
        shape = tuple(
            dim if not isinstance(dim, str) and dim is not None and dim > 0 else 1
            for dim in input_meta.shape)
        if input_meta.type in float_dict:
            # random_sample(size=()) returns a 0-d ndarray, whereas
            # np.random.rand() with no arguments returns a Python float that
            # has no .astype(). For non-empty shapes the draws are identical.
            data = np.random.random_sample(size=shape)
            feeds[input_meta.name] = data.astype(float_dict[input_meta.type])
        elif input_meta.type in integer_dict:
            dtype = np.dtype(integer_dict[input_meta.type])
            # Draw from [0, 1000), capped at the dtype's maximum: casting
            # out-of-range floats to int8/uint8 does not yield meaningful
            # values. Types whose max is >= 1000 get exactly the old values.
            high = min(1000, int(np.iinfo(dtype).max))
            feeds[input_meta.name] = np.random.uniform(
                high=high, size=shape).astype(dtype)
        elif input_meta.type == 'tensor(bool)':
            feeds[input_meta.name] = np.random.randint(2, size=shape).astype('bool')
        else:
            print("unsupported input type {} for input {}".format(input_meta.type, input_meta.name))
            sys.exit(-1)
    return feeds
def main():
    """Fetch one tensor from each of two ONNX models and compare them.

    ``args.input`` names the two model files. ``args.node`` names the tensor
    to fetch from each: ``node[0]`` from the first model, ``node[1]`` from
    the second. Each model gets its own feed from ``make_feed``.

    The fetched results are compared with ``np.testing.assert_allclose`` at
    ``rtol=0.1``, which raises ``AssertionError`` on a mismatch. Otherwise
    "ok" and the two tensor names are printed.

    Returns:
        0, used as the process exit status.
    """
    args = get_args()
    first = onnxrt.InferenceSession(args.input[0])
    second = onnxrt.InferenceSession(args.input[1])

    def fetch(session, which):
        # Only the single requested tensor is fetched, not all outputs.
        inputs = make_feed(session)
        return session.run([args.node[which]], inputs)

    lhs = fetch(first, 0)
    rhs = fetch(second, 1)
    np.testing.assert_allclose(lhs, rhs, rtol=0.1)
    print("ok {}, {}".format(args.node[0], args.node[1]))
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment