@zachmayer
Last active June 7, 2024 22:03
Export surya layout model to ONNX
# Copyright 2024 AI INSIGHT SOLUTIONS INC.
import timeit
from pathlib import Path
import numpy as np
import onnx
import onnxruntime as rt
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from onnxruntime.quantization.preprocess import quant_pre_process
from optimum.onnxruntime import ORTModelForSemanticSegmentation
from surya.model.detection.segformer import load_model, load_processor
# Settings
model_name = "vikp/surya_layout2"
onnx_folder = Path("~/Downloads/").expanduser() / model_name
onnx_folder.mkdir(parents=True, exist_ok=True)
file_onnx = "model.onnx"
file_prep_onnx = "model_prep.onnx"
file_opt_onnx = "model_prep_opt.onnx"
file_quant_onnx = "model_prep_opt_quant.onnx"
onnx_model_class = ORTModelForSemanticSegmentation
# Load model
model = load_model(checkpoint=model_name).eval()
processor = load_processor(checkpoint=model_name)
# Create an example input
np.random.seed(42)
sample_pixels = processor.preprocess(
    images=np.random.randint(0, 255, (512, 512, 3)), return_tensors="pt"
)["pixel_values"]
# Export the model to ONNX
model.config.save_pretrained(onnx_folder)
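# dynamic_axes below lets batch size and image height/width vary at inference time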
torch.onnx.export(
    model,
    sample_pixels,
    f=onnx_folder / file_onnx,
    input_names=["pixel_values"],
    output_names=["logits"],
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "logits": {0: "batch_size"},
    },
)
# Preprocess for quantization
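# quant_pre_process runs shape inference and basic optimization so quantization sees fully-shaped tensors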
quant_pre_process(
    input_model_path=onnx_folder / file_onnx,
    output_model_path=onnx_folder / file_prep_onnx,
    auto_merge=True,
    all_tensors_to_one_file=True,
)
# Optimize the model
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = (
    rt.GraphOptimizationLevel.ORT_ENABLE_BASIC
)  # ORT_ENABLE_BASIC / ORT_ENABLE_EXTENDED / ORT_ENABLE_ALL
sess_options.optimized_model_filepath = str(onnx_folder / file_opt_onnx)
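# Creating this session applies the graph optimizations and writes the optimized model to optimized_model_filepath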
session = rt.InferenceSession(str(onnx_folder / file_prep_onnx), sess_options)
# Quantize the ONNX model
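# quantize_dynamic converts the weights to uint8 and writes the quantized model to model_output (it returns None)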
quantize_dynamic(
    model_input=onnx_folder / file_opt_onnx,
    model_output=onnx_folder / file_quant_onnx,
    weight_type=QuantType.QUInt8,
)
# Compare the graphs in onnx
original_onnx_model = onnx.load(onnx_folder / file_onnx)
prep_onnx_model = onnx.load(onnx_folder / file_prep_onnx)
optimized_onnx_model = onnx.load(onnx_folder / file_opt_onnx)
quantized_onnx_model = onnx.load(onnx_folder / file_quant_onnx)
print(f"Onnx model graph: {len(original_onnx_model.graph.node)}")
print(f"Prep model graph: {len(prep_onnx_model.graph.node)}")
print(f"Optimized model graph: {len(optimized_onnx_model.graph.node)}")
print(f"Quantized model graph: {len(quantized_onnx_model.graph.node)}")
# Load models for inference
model_onnx = onnx_model_class.from_pretrained(onnx_folder, file_name=file_onnx)
model_prep = onnx_model_class.from_pretrained(onnx_folder, file_name=file_prep_onnx)
model_opt = onnx_model_class.from_pretrained(onnx_folder, file_name=file_opt_onnx)
model_quant = onnx_model_class.from_pretrained(onnx_folder, file_name=file_quant_onnx)
# Run inference
outputs = model(sample_pixels)["logits"].detach().numpy()
onnx_outputs = model_onnx(sample_pixels)["logits"].detach().numpy()
prep_outputs = model_prep(sample_pixels)["logits"].detach().numpy()
opt_outputs = model_opt(sample_pixels)["logits"].detach().numpy()
quant_outputs = model_quant(sample_pixels)["logits"].detach().numpy()
# Check inference
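# The quantized model is only expected to match loosely, so it gets wider tolerances below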
assert outputs.shape[1] == len(model.config.id2label)
assert (
    outputs.shape
    == onnx_outputs.shape
    == prep_outputs.shape
    == opt_outputs.shape
    == quant_outputs.shape
)
assert np.allclose(outputs, onnx_outputs, rtol=0.0, atol=1e-5)
assert np.allclose(outputs, prep_outputs, rtol=0.0, atol=1e-5)
assert np.allclose(outputs, opt_outputs, rtol=0.0, atol=1e-5)
assert np.allclose(outputs, quant_outputs, rtol=0.02, atol=0.1)
# Timings
reps = 10
orig_time = timeit.timeit("model(sample_pixels)", globals=globals(), number=reps)
onnx_time = timeit.timeit("model_onnx(sample_pixels)", globals=globals(), number=reps)
prep_time = timeit.timeit("model_prep(sample_pixels)", globals=globals(), number=reps)
opt_time = timeit.timeit("model_opt(sample_pixels)", globals=globals(), number=reps)
quant_time = timeit.timeit("model_quant(sample_pixels)", globals=globals(), number=reps)
# Show results
print(f"Original Model Time: {(orig_time/reps):.2f} seconds per rep")
print(f"Onnx Model Time: {(onnx_time/reps):.2f} seconds per rep")
print(f"Preprocessed Onnx Model Time: {(prep_time/reps):.2f} seconds per rep")
print(f"Optimized Onnx Model Time: {(opt_time/reps):.2f} seconds per rep")
print(f"Quantized Onnx Model Time: {(quant_time/reps):.2f} seconds per rep")
@zachmayer (Author):

File size:

  • original model: 120 MB
  • onnx model: 120.5 MB
  • optimized onnx model: 120.3 MB
  • quantized model: 31 MB

Inference speed (on CPU, on my M2 Mac):

  • original model: 0.54 seconds per image
  • onnx model: 0.26 seconds per image
  • optimized onnx model: 0.25 seconds per image
  • quantized model: 0.40 seconds per image

Inference accuracy:

  • original model: N/A (this is the baseline)
  • onnx model: almost identical
  • optimized onnx model: almost identical
  • quantized model: some differences

Conclusions:

  • onnx inference cuts per-image time roughly in half (0.54s → 0.26s), at the same file size
  • optimizing the onnx model gives a small additional speedup on top of that, again at the same file size
  • quantizing gives a ~25% speedup over the original model (though it's slower than the plain onnx model) and a ~75% smaller file, at the cost of some changes in the output and accuracy

All said, I'm not convinced of the benefits of quantizing unless you really need a smaller file size (e.g. a 31 MB file fits in a git repo without LFS).

Using onnx for inference with the optimizations does seem worth it, though: faster inference, same file size, and basically the same outputs.
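
If you only need the faster inference path, a plain onnxruntime session on the optimized ONNX file works without the optimum wrapper. This is a minimal sketch assuming the checkpoint and file paths from the script above and a dummy 512x512 image; swap in your own preprocessed pixel values:

from pathlib import Path
import numpy as np
import onnxruntime as rt
from surya.model.detection.segformer import load_processor

# Checkpoint and path assumed from the export script above
model_name = "vikp/surya_layout2"
onnx_path = Path("~/Downloads/").expanduser() / model_name / "model_prep_opt.onnx"

# Preprocess exactly as during export (dummy 512x512 RGB image here)
processor = load_processor(checkpoint=model_name)
pixels = processor.preprocess(
    images=np.random.randint(0, 255, (512, 512, 3)), return_tensors="pt"
)["pixel_values"].numpy()

# Run the optimized model directly with onnxruntime on CPU
session = rt.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
(logits,) = session.run(["logits"], {"pixel_values": pixels})
print(logits.shape)  # same (batch, num_labels, h, w) logits as the torch model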
