Created November 15, 2023 18:41
-
-
Save johncf/3f663976ac1b892d488076fc739f643d to your computer and use it in GitHub Desktop.
Inspect an ONNX model file for sanity
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Inspect an ONNX model file for sanity: list its required inputs and its
# outputs, then run the ONNX model checker, the graph checker and a model
# extraction, reporting each result instead of stopping at the first failure.
import os
import sys
import tempfile
import traceback

import onnx


def _run_check(description, check):
    """Run one sanity check and print whether it succeeded.

    Args:
        description: Name used in the report, e.g. "Model check".
        check: Zero-argument callable; any exception it raises is a failure.

    Returns:
        True if ``check`` returned without raising, False otherwise.
    """
    try:
        check()
    except Exception:
        # Deliberately broad: this is a diagnostic tool, so report every
        # failure with its traceback and carry on with the remaining checks.
        print(f"{description} failed with error:")
        print(traceback.format_exc())
        return False
    print(f"{description} successful!\n")
    return True


def main(argv=None):
    """Inspect the ONNX model whose path is given on the command line.

    Args:
        argv: Argument vector with the program name first; defaults to
            ``sys.argv``.

    Returns:
        Process exit status: 0 if every check passed, 1 on a usage error or
        if any check failed. A failure to load the model at all is not
        caught and propagates as an exception.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) != 2:
        print(f"Usage: {argv[0]} /path/to/model.onnx", file=sys.stderr)
        return 1
    model_path = argv[1]

    model = onnx.load(model_path)
    graph = model.graph

    # Inputs keyed by name; the dict preserves the graph's declaration order.
    input_all = {inp.name: inp for inp in graph.input}
    print("all inputs count:", len(input_all))

    # Graph inputs that share a name with an initializer already have a
    # value, so only the rest must be fed by the caller. A list (not a set)
    # keeps the report and the extraction arguments in a deterministic order.
    initializer_names = {init.name for init in graph.initializer}
    feed_input_names = [
        name for name in input_all if name not in initializer_names
    ]

    print("required inputs:")
    for name in feed_input_names:
        print(input_all[name])
    print()

    print("outputs:")
    for out in graph.output:
        print(out)
    print()

    # Portable replacement for a hard-coded "/tmp/..." path.
    extracted_path = os.path.join(
        tempfile.gettempdir(), "extracted_model.onnx"
    )

    # Passing the path (not the loaded proto) lets the checker handle models
    # larger than the 2 GB protobuf serialization limit.
    model_ok = _run_check(
        "Model check",
        lambda: onnx.checker.check_model(model_path, full_check=True),
    )
    graph_ok = _run_check(
        "Model graph check",
        lambda: onnx.checker.check_graph(graph),
    )
    extract_ok = _run_check(
        "Model extraction",
        lambda: onnx.utils.extract_model(
            model_path,
            extracted_path,
            input_names=feed_input_names,
            output_names=[out.name for out in graph.output],
        ),
    )
    if extract_ok:
        print(f"Extracted model written to: {extracted_path}\n")

    return 0 if (model_ok and graph_ok and extract_ok) else 1


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.