Export surya layout model to ONNX
# Copyright 2024 AI INSIGHT SOLUTIONS INC.
import timeit
from pathlib import Path

import numpy as np
import onnx
import onnxruntime as rt
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from onnxruntime.quantization.preprocess import quant_pre_process
from optimum.onnxruntime import ORTModelForSemanticSegmentation
from surya.model.detection.segformer import load_model, load_processor

# Settings
model_name = "vikp/surya_layout2"
onnx_folder = Path("~/Downloads/").expanduser() / model_name
onnx_folder.mkdir(parents=True, exist_ok=True)
file_onnx = "model.onnx"
file_prep_onnx = "model_prep.onnx"
file_opt_onnx = "model_prep_opt.onnx"
file_quant_onnx = "model_prep_opt_quant.onnx"
onnx_model_class = ORTModelForSemanticSegmentation

# Load model
model = load_model(checkpoint=model_name).eval()
processor = load_processor(checkpoint=model_name)

# Create an example input
np.random.seed(42)
sample_pixels = processor.preprocess(
    images=np.random.randint(0, 255, (512, 512, 3)), return_tensors="pt"
)["pixel_values"]

# Export the model to ONNX
model.config.save_pretrained(onnx_folder)
torch.onnx.export(
    model,
    sample_pixels,
    f=onnx_folder / file_onnx,
    input_names=["pixel_values"],
    output_names=["logits"],
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "logits": {0: "batch_size"},
    },
)

# Preprocess for quantization
quant_pre_process(
    input_model_path=onnx_folder / file_onnx,
    output_model_path=onnx_folder / file_prep_onnx,
    auto_merge=True,
    all_tensors_to_one_file=True,
)

# Optimize the model
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = (
    rt.GraphOptimizationLevel.ORT_ENABLE_BASIC
)  # ORT_ENABLE_BASIC / ORT_ENABLE_EXTENDED / ORT_ENABLE_ALL
sess_options.optimized_model_filepath = str(onnx_folder / file_opt_onnx)
session = rt.InferenceSession(str(onnx_folder / file_prep_onnx), sess_options)

# Quantize the ONNX model
quantized_model = quantize_dynamic(
    model_input=onnx_folder / file_opt_onnx,
    model_output=onnx_folder / file_quant_onnx,
    weight_type=QuantType.QUInt8,
)

# Compare the graphs in onnx
original_onnx_model = onnx.load(onnx_folder / file_onnx)
prep_onnx_model = onnx.load(onnx_folder / file_prep_onnx)
optimized_onnx_model = onnx.load(onnx_folder / file_opt_onnx)
quantized_onnx_model = onnx.load(onnx_folder / file_quant_onnx)
print(f"Onnx model graph: {len(original_onnx_model.graph.node)}")
print(f"Prep model graph: {len(prep_onnx_model.graph.node)}")
print(f"Optimized model graph: {len(optimized_onnx_model.graph.node)}")
print(f"Quantized model graph: {len(quantized_onnx_model.graph.node)}")

# Load models for inference
model_onnx = onnx_model_class.from_pretrained(onnx_folder, file_name=file_onnx)
model_prep = onnx_model_class.from_pretrained(onnx_folder, file_name=file_prep_onnx)
model_opt = onnx_model_class.from_pretrained(onnx_folder, file_name=file_opt_onnx)
model_quant = onnx_model_class.from_pretrained(onnx_folder, file_name=file_quant_onnx)

# Run inference
outputs = model(sample_pixels)["logits"].detach().numpy()
onnx_outputs = model_onnx(sample_pixels)["logits"].detach().numpy()
prep_outputs = model_prep(sample_pixels)["logits"].detach().numpy()
opt_outputs = model_opt(sample_pixels)["logits"].detach().numpy()
quant_outputs = model_quant(sample_pixels)["logits"].detach().numpy()

# Check inference
assert outputs.shape[1] == len(model.config.id2label)
assert (
    outputs.shape
    == onnx_outputs.shape
    == prep_outputs.shape
    == opt_outputs.shape
    == quant_outputs.shape
)
assert np.allclose(outputs, onnx_outputs, rtol=0.0, atol=1e-5)
assert np.allclose(outputs, prep_outputs, rtol=0.0, atol=1e-5)
assert np.allclose(outputs, opt_outputs, rtol=0.0, atol=1e-5)
assert np.allclose(outputs, quant_outputs, rtol=0.02, atol=0.1)

# Timings
reps = 10
orig_time = timeit.timeit("model(sample_pixels)", globals=globals(), number=reps)
onnx_time = timeit.timeit("model_onnx(sample_pixels)", globals=globals(), number=reps)
prep_time = timeit.timeit("model_prep(sample_pixels)", globals=globals(), number=reps)
opt_time = timeit.timeit("model_opt(sample_pixels)", globals=globals(), number=reps)
quant_time = timeit.timeit("model_quant(sample_pixels)", globals=globals(), number=reps)

# Show results
print(f"Original Model Time: {(orig_time/reps):.2f} seconds per rep")
print(f"Onnx Model Time: {(onnx_time/reps):.2f} seconds per rep")
print(f"Preprocessed Onnx Model Time: {(prep_time/reps):.2f} seconds per rep")
print(f"Optimized Onnx Model Time: {(opt_time/reps):.2f} seconds per rep")
print(f"Quantized Onnx Model Time: {(quant_time/reps):.2f} seconds per rep")
File size:
Inference speed (on CPU on my M2 Mac):
Inference accuracy:
Conclusions:
All said, I'm not convinced of the benefits of quantizing unless you really need a smaller file size (e.g., a 31 MB file can live in a git repo without LFS).
Using ONNX for inference with the optimizations does seem worth it, though: faster inference, the same file size, and basically the same outputs.
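If file size is the deciding factor, a quick sanity check is to stat the exported files. This is just a sketch using the paths from the script above; the exact numbers depend on the checkpoint.

# Sketch: compare the on-disk sizes of the exported models (paths as in the script).
from pathlib import Path

folder = Path("~/Downloads/").expanduser() / "vikp/surya_layout2"
for name in ["model.onnx", "model_prep.onnx", "model_prep_opt.onnx", "model_prep_opt_quant.onnx"]:
    size_mb = (folder / name).stat().st_size / 1e6
    print(f"{name}: {size_mb:.0f} MB")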