Created
July 29, 2025 20:29
-
-
Save josepsmartinez/361214d8fa28999d6995a5c07b39e866 to your computer and use it in GitHub Desktop.
Custom export of YOLOv5 to ONNX with NMS
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import onnxruntime | |
import numpy as np | |
import onnx | |
from onnx import helper, numpy_helper | |
import onnx_graphsurgeon as gs | |
def add_nms(
    onnx_model: onnx.ModelProto,
    *,
    max_output_boxes_per_class: int = 100,
    iou_threshold: float = 0.45,
    score_threshold: float = 0.25,
) -> onnx.ModelProto:
    """Append a NonMaxSuppression stage to a model that outputs boxes and scores.

    The first two graph outputs of ``onnx_model`` are taken as the boxes
    tensor ``(B, N, 4)`` and the per-class scores tensor ``(B, N, C)``. The
    returned model replaces all graph outputs with four tensors, in this order:

    * ``batch_idx_sq``  -- ``(num_detections,)`` int64 batch index per detection
    * ``nms_boxes``     -- ``(num_detections, 4)`` float32 selected boxes
    * ``nms_scores``    -- ``(num_detections,)`` float32 selected scores
    * ``class_idx_sq``  -- ``(num_detections,)`` int64 class index per detection

    NOTE(review): ``Split`` with a sizes input and ``Squeeze`` with an axes
    input are the opset >= 13 forms of those operators -- confirm the model's
    opset before using this on older exports.

    Args:
        onnx_model: Model whose outputs 0 and 1 are boxes and scores.
        max_output_boxes_per_class: Upper bound on boxes kept per class and
            batch item, passed to NonMaxSuppression.
        iou_threshold: IoU above which a lower-scoring box is suppressed.
        score_threshold: Boxes scoring at or below this value are discarded
            before suppression.

    Returns:
        A new ``onnx.ModelProto`` exported from the modified graph.
    """
    graph = gs.import_onnx(onnx_model)
    boxes_var = graph.outputs[0]  # (B, N, 4)
    scores_var = graph.outputs[1]  # (B, N, C)

    # Label the dimensions symbolically so the exported graph stays dynamic.
    boxes_var.shape = ["batch_size", "num_boxes", 4]
    scores_var.shape = ["batch_size", "num_boxes", "num_classes"]

    # NonMaxSuppression wants scores laid out as (B, C, N).
    scores_transposed = gs.Variable("scores_nms", dtype=np.float32)
    graph.nodes.append(
        gs.Node(
            op="Transpose",
            inputs=[scores_var],
            outputs=[scores_transposed],
            attrs={"perm": [0, 2, 1]},
        )
    )

    # NMS hyper-parameters are fed to the operator as 1-element constants.
    max_out_const = gs.Constant(
        "max_out", np.array([max_output_boxes_per_class], dtype=np.int64)
    )
    iou_const = gs.Constant(
        "iou_thresh", np.array([iou_threshold], dtype=np.float32)
    )
    score_const = gs.Constant(
        "score_thresh", np.array([score_threshold], dtype=np.float32)
    )

    # NMS output: (num_detections, 3) rows of [batch_idx, class_idx, box_idx].
    nms_out = gs.Variable("nms_indices", dtype=np.int64)
    graph.nodes.append(
        gs.Node(
            op="NonMaxSuppression",
            inputs=[
                boxes_var,
                scores_transposed,
                max_out_const,
                iou_const,
                score_const,
            ],
            outputs=[nms_out],
        )
    )

    # Split the index triples into three (num_detections, 1) columns.
    batch_idx = gs.Variable("batch_idx", dtype=np.int64)
    class_idx = gs.Variable("class_idx", dtype=np.int64)
    box_idx = gs.Variable("box_idx", dtype=np.int64)
    graph.nodes.append(
        gs.Node(
            op="Split",
            inputs=[
                nms_out,
                gs.Constant("split_sizes", np.array([1, 1, 1], dtype=np.int64)),
            ],
            outputs=[batch_idx, class_idx, box_idx],
            attrs={"axis": 1},
        )
    )

    # Only the two columns that are exposed as graph outputs need flattening
    # to (num_detections,); the gathers below consume the 2-D columns as-is.
    squeeze_axes = gs.Constant("squeeze_axes", np.array([1], dtype=np.int64))
    batch_idx_sq = gs.Variable("batch_idx_sq", dtype=np.int64)
    class_idx_sq = gs.Variable("class_idx_sq", dtype=np.int64)
    graph.nodes.extend(
        [
            gs.Node(
                op="Squeeze",
                inputs=[batch_idx, squeeze_axes],
                outputs=[batch_idx_sq],
            ),
            gs.Node(
                op="Squeeze",
                inputs=[class_idx, squeeze_axes],
                outputs=[class_idx_sq],
            ),
        ]
    )

    # Boxes: index (B, N, 4) with [batch_idx, box_idx] -> (num_detections, 4).
    batch_box_indices = gs.Variable("batch_box_indices", dtype=np.int64)
    graph.nodes.append(
        gs.Node(
            op="Concat",
            inputs=[batch_idx, box_idx],
            outputs=[batch_box_indices],
            attrs={"axis": 1},
        )
    )
    gathered_boxes = gs.Variable("nms_boxes", dtype=np.float32)
    graph.nodes.append(
        gs.Node(
            op="GatherND",
            inputs=[boxes_var, batch_box_indices],
            outputs=[gathered_boxes],
        )
    )

    # Scores: the NMS rows are already [batch, class, box], which is exactly
    # the axis order of the transposed (B, C, N) scores, so they can be used
    # as GatherND indices directly -> (num_detections,).
    gathered_scores = gs.Variable("nms_scores", dtype=np.float32)
    graph.nodes.append(
        gs.Node(
            op="GatherND",
            inputs=[scores_transposed, nms_out],
            outputs=[gathered_scores],
        )
    )

    # Replace the original outputs with the post-NMS detections.
    graph.outputs = [
        batch_idx_sq,  # (num_detections,) batch indices
        gathered_boxes,  # (num_detections, 4) boxes
        gathered_scores,  # (num_detections,) scores
        class_idx_sq,  # (num_detections,) classes
    ]
    graph.cleanup().toposort()
    return gs.export_onnx(graph)
def add_v5_split(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Split a raw YOLOv5-style output into corner-format boxes and class scores.

    The first graph output is treated as a rank-3 tensor whose last axis is
    laid out as ``[x, y, w, h, objectness, class_0, ..., class_{C-1}]``
    (assumed ``(B, N, 5 + C)`` -- confirm against the exported model). The
    returned model has exactly two outputs:

    * ``boxes_xyxy`` -- ``(B, N, 4)`` boxes as ``[x1, y1, x2, y2]``, computed
      from centre/size as ``centre -/+ size / 2``
    * ``scores``     -- ``(B, N, C)`` objectness multiplied by each class score

    Slicing is done along the last axis only, so the batch size, the number of
    candidate boxes and the number of classes are not baked into the graph.

    Args:
        onnx_model: Model whose output 0 is the raw detection tensor.

    Returns:
        A new ``onnx.ModelProto`` exported from the modified graph.
    """
    graph = gs.import_onnx(onnx_model)
    raw_output = graph.outputs[0]
    new_nodes = []

    # ONNX Slice clamps `ends` to the dimension size, so the largest int64
    # means "through the end of the axis" whatever the class count is.
    to_end = np.iinfo(np.int64).max

    def slice_channels(name, start, end):
        """Slice raw_output[..., start:end] along the last (channel) axis."""
        node = gs.Node(
            op="Slice",
            name=name,
            inputs=[
                raw_output,
                gs.Constant(
                    name=f"{name}_start",
                    values=np.array([start], dtype=np.int64),
                ),
                gs.Constant(
                    name=f"{name}_end",
                    values=np.array([end], dtype=np.int64),
                ),
                gs.Constant(
                    name=f"{name}_axes",
                    values=np.array([2], dtype=np.int64),
                ),
                gs.Constant(
                    name=f"{name}_steps",
                    values=np.array([1], dtype=np.int64),
                ),
            ],
            outputs=[gs.Variable(name=f"{name}_out", dtype=np.float32)],
        )
        new_nodes.append(node)
        return node.outputs[0]

    def half(name, val):
        """Return a tensor equal to ``val / 2``."""
        node = gs.Node(
            op="Div",
            name=f"div2_{name}",
            inputs=[
                val,
                gs.Constant(
                    name=f"two_{name}",
                    values=np.array([2.0], dtype=np.float32),
                ),
            ],
            outputs=[gs.Variable(name=f"{name}_half", dtype=np.float32)],
        )
        new_nodes.append(node)
        return node.outputs[0]

    def arith(op, name, a, b):
        """Emit the binary elementwise node ``op(a, b)`` and return its output."""
        node = gs.Node(
            op=op,
            name=name,
            inputs=[a, b],
            outputs=[gs.Variable(name=f"{name}_out", dtype=np.float32)],
        )
        new_nodes.append(node)
        return node.outputs[0]

    # Per-channel views of the raw tensor, each (B, N, 1) except the classes.
    x = slice_channels("slice_x", 0, 1)
    y = slice_channels("slice_y", 1, 2)
    w = slice_channels("slice_w", 2, 3)
    h = slice_channels("slice_h", 3, 4)
    objectness = slice_channels("slice_obj", 4, 5)
    class_scores = slice_channels("slice_classes", 5, to_end)

    # Centre/size -> corner coordinates.
    hw = half("w", w)
    hh = half("h", h)
    x1 = arith("Sub", "x1", x, hw)
    y1 = arith("Sub", "y1", y, hh)
    x2 = arith("Add", "x2", x, hw)
    y2 = arith("Add", "y2", y, hh)

    concat_boxes = gs.Node(
        op="Concat",
        name="concat_boxes_xyxy",
        inputs=[x1, y1, x2, y2],
        outputs=[gs.Variable(name="boxes_xyxy", dtype=np.float32)],
        attrs={"axis": 2},
    )
    new_nodes.append(concat_boxes)

    # Final confidence per class: objectness (B, N, 1) broadcasts over C.
    mul_scores = gs.Node(
        op="Mul",
        name="final_scores",
        inputs=[objectness, class_scores],
        outputs=[gs.Variable(name="scores", dtype=np.float32)],
    )
    new_nodes.append(mul_scores)

    # Add new nodes to the graph and replace its outputs.
    graph.nodes.extend(new_nodes)
    graph.outputs = [concat_boxes.outputs[0], mul_scores.outputs[0]]
    graph.cleanup().toposort()
    return gs.export_onnx(graph)
def verify_model_output(onnx_model_path, input_shape=(4, 3, 640, 640)):
    """Run one random batch through a model and print each output's shape.

    This is a smoke test only: it checks that the model loads and executes,
    and shows the resulting output shapes; it does not validate values.

    Args:
        onnx_model_path: Path of the ONNX file to load with onnxruntime.
        input_shape: Shape of the random float32 tensor fed to every model
            input. Defaults to a batch of four 3x640x640 images.
    """
    session_input = np.random.random(input_shape).astype(np.float32)
    ort_session = onnxruntime.InferenceSession(
        onnx_model_path,
        # Prefer CUDA, but name the CPU provider explicitly as the fallback
        # for machines without a usable GPU build.
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

    # Every declared model input receives the same random tensor.
    feeds = {
        input_arg.name: session_input for input_arg in ort_session.get_inputs()
    }
    outputs = ort_session.run(None, feeds)

    # run(None, ...) returns values in the same order as get_outputs().
    for output_arg, output_value in zip(ort_session.get_outputs(), outputs):
        print(f"{output_arg.name}: {output_value.shape}")
def export_model(onnx_model: onnx.ModelProto, output_filepath, ir_version=10):
    """Stamp an IR version on ``onnx_model`` and write it to disk.

    Note that ``onnx_model`` is modified in place: its ``ir_version`` field is
    overwritten before saving.

    Args:
        onnx_model: Model to serialize.
        output_filepath: Destination path handed to ``onnx.save``.
        ir_version: IR version written into the model. Defaults to 10.
            NOTE(review): presumably pinned so the file loads in the target
            onnxruntime build -- confirm before changing.
    """
    onnx_model.ir_version = ir_version
    onnx.save(onnx_model, output_filepath)
if __name__ == "__main__":
    # Source export and destination for the NMS-augmented model.
    base_path = "models/exported/yolon.onnx"
    nms_path = "models/exported/yolon-with_nms.onnx"

    # Smoke-test the untouched export before modifying anything.
    verify_model_output(base_path)

    # Raw head -> (boxes, scores) -> NMS detections, then write to disk.
    split_model = add_v5_split(onnx.load(base_path))
    export_model(add_nms(split_model), nms_path)

    # Smoke-test the rewritten model as well.
    verify_model_output(nms_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment