Skip to content

Instantly share code, notes, and snippets.

@josepsmartinez
Created July 29, 2025 20:29
Show Gist options
  • Save josepsmartinez/361214d8fa28999d6995a5c07b39e866 to your computer and use it in GitHub Desktop.
Save josepsmartinez/361214d8fa28999d6995a5c07b39e866 to your computer and use it in GitHub Desktop.
Custom export of YOLOv5 to ONNX with NMS
import onnxruntime
import numpy as np
import onnx
from onnx import helper, numpy_helper
import onnx_graphsurgeon as gs
def add_nms(
    onnx_model: onnx.ModelProto,
    max_output_boxes_per_class: int = 100,
    iou_threshold: float = 0.45,
    score_threshold: float = 0.25,
) -> onnx.ModelProto:
    """Append a NonMaxSuppression stage to a model exposing boxes and scores.

    The first graph output is used as the box tensor and the second as the
    score tensor; they are expected to be shaped (B, N, 4) and (B, N, C)
    respectively. The original outputs are replaced by the NMS results.

    Args:
        onnx_model: Model whose first two outputs are boxes and scores.
        max_output_boxes_per_class: Upper bound on boxes kept per batch item
            and class, passed to the NonMaxSuppression node.
        iou_threshold: IoU above which a lower-scored box is suppressed.
        score_threshold: Boxes scoring at or below this value are discarded.

    Returns:
        A new model with four outputs, where D is the number of detections:
        ``batch_idx_sq`` (D,), ``nms_boxes`` (D, 4), ``nms_scores`` (D,) and
        ``class_idx_sq`` (D,).

    Raises:
        IndexError: If the model has fewer than two graph outputs.
    """
    graph = gs.import_onnx(onnx_model)
    boxes_var = graph.outputs[0]   # expected (B, N, 4)
    scores_var = graph.outputs[1]  # expected (B, N, C)

    # Label the dimensions with symbolic names so the exported graph stays
    # dynamic along batch, box and class axes.
    boxes_var.shape = ["batch_size", "num_boxes", 4]
    scores_var.shape = ["batch_size", "num_boxes", "num_classes"]

    # NonMaxSuppression wants scores laid out as (B, C, N).
    scores_transposed = gs.Variable("scores_nms", dtype=np.float32)
    graph.nodes.append(
        gs.Node(
            op="Transpose",
            inputs=[scores_var],
            outputs=[scores_transposed],
            attrs={"perm": [0, 2, 1]},
        )
    )

    max_out_const = gs.Constant(
        "max_out", np.array([max_output_boxes_per_class], dtype=np.int64)
    )
    iou_const = gs.Constant(
        "iou_thresh", np.array([iou_threshold], dtype=np.float32)
    )
    score_const = gs.Constant(
        "score_thresh", np.array([score_threshold], dtype=np.float32)
    )

    # NMS emits (D, 3) rows of [batch_idx, class_idx, box_idx].
    nms_out = gs.Variable("nms_indices", dtype=np.int64)
    graph.nodes.append(
        gs.Node(
            op="NonMaxSuppression",
            inputs=[
                boxes_var,
                scores_transposed,
                max_out_const,
                iou_const,
                score_const,
            ],
            outputs=[nms_out],
        )
    )

    # Split the index triples into three (D, 1) columns.
    # NOTE(review): Split/Squeeze take sizes/axes as inputs here, which is the
    # opset >= 13 form; confirm the model's opset.
    batch_idx = gs.Variable("batch_idx", dtype=np.int64)
    class_idx = gs.Variable("class_idx", dtype=np.int64)
    box_idx = gs.Variable("box_idx", dtype=np.int64)
    split_sizes = gs.Constant(
        "split_sizes", np.array([1, 1, 1], dtype=np.int64)
    )
    graph.nodes.append(
        gs.Node(
            op="Split",
            inputs=[nms_out, split_sizes],
            outputs=[batch_idx, class_idx, box_idx],
            attrs={"axis": 1},
        )
    )

    # Only the batch and class columns are exposed as 1-D outputs, so only
    # those need squeezing to (D,).
    squeeze_axes = gs.Constant("squeeze_axes", np.array([1], dtype=np.int64))
    batch_idx_sq = gs.Variable("batch_idx_sq", dtype=np.int64)
    class_idx_sq = gs.Variable("class_idx_sq", dtype=np.int64)
    graph.nodes.extend(
        [
            gs.Node(
                op="Squeeze",
                inputs=[batch_idx, squeeze_axes],
                outputs=[batch_idx_sq],
            ),
            gs.Node(
                op="Squeeze",
                inputs=[class_idx, squeeze_axes],
                outputs=[class_idx_sq],
            ),
        ]
    )

    # The Split outputs are already (D, 1), so they can be concatenated
    # directly into GatherND index tuples without a squeeze/unsqueeze trip.
    # Boxes are addressed by (batch, box) -> (D, 4).
    batch_box_indices = gs.Variable("batch_box_indices", dtype=np.int64)
    graph.nodes.append(
        gs.Node(
            op="Concat",
            inputs=[batch_idx, box_idx],
            outputs=[batch_box_indices],
            attrs={"axis": 1},
        )
    )
    gathered_boxes = gs.Variable("nms_boxes", dtype=np.float32)
    graph.nodes.append(
        gs.Node(
            op="GatherND",
            inputs=[boxes_var, batch_box_indices],
            outputs=[gathered_boxes],
        )
    )

    # Scores are addressed by (batch, box, class) in the untransposed
    # (B, N, C) tensor -> (D,).
    batch_box_class_indices = gs.Variable(
        "batch_box_class_indices", dtype=np.int64
    )
    graph.nodes.append(
        gs.Node(
            op="Concat",
            inputs=[batch_idx, box_idx, class_idx],
            outputs=[batch_box_class_indices],
            attrs={"axis": 1},
        )
    )
    gathered_scores = gs.Variable("nms_scores", dtype=np.float32)
    graph.nodes.append(
        gs.Node(
            op="GatherND",
            inputs=[scores_var, batch_box_class_indices],
            outputs=[gathered_scores],
        )
    )

    # Replace the original outputs with the post-NMS detections.
    graph.outputs.clear()
    graph.outputs.extend(
        [
            batch_idx_sq,     # (D,) batch indices
            gathered_boxes,   # (D, 4) boxes
            gathered_scores,  # (D,) scores
            class_idx_sq,     # (D,) class indices
        ]
    )

    graph.cleanup().toposort()
    return gs.export_onnx(graph)
def add_v5_split(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Split a fused YOLOv5 prediction tensor into xyxy boxes and class scores.

    The first graph output is treated as a rank-3 tensor whose last axis holds
    ``[x, y, w, h, objectness, class_0, ..., class_k]`` (for example
    (1, 25200, 17)). Columns 0-3 are converted from centre/size form to corner
    form, and the class columns are multiplied by the objectness column.

    Slicing is done along the last axis only, so the batch size, the number of
    candidate boxes and the number of classes are not baked into the graph.

    Args:
        onnx_model: Model whose first output is the fused prediction tensor.

    Returns:
        A new model with two outputs: ``boxes_xyxy`` (B, N, 4) and
        ``scores`` (B, N, C).
    """
    graph = gs.import_onnx(onnx_model)
    raw_output = graph.outputs[0]
    new_nodes = []

    # Open-ended slice end; Slice clamps it to the actual dimension size.
    to_end = np.iinfo(np.int64).max

    def slice_op(name, start, end):
        """Slice columns [start, end) of the last axis of the raw output."""
        node = gs.Node(
            op="Slice",
            name=name,
            inputs=[
                raw_output,
                gs.Constant(
                    name=f"{name}_start",
                    values=np.array([start], dtype=np.int64),
                ),
                gs.Constant(
                    name=f"{name}_end",
                    values=np.array([end], dtype=np.int64),
                ),
                gs.Constant(
                    name=f"{name}_axes",
                    values=np.array([2], dtype=np.int64),
                ),
                gs.Constant(
                    name=f"{name}_steps",
                    values=np.array([1], dtype=np.int64),
                ),
            ],
            outputs=[gs.Variable(name=f"{name}_out", dtype=np.float32)],
        )
        new_nodes.append(node)
        return node.outputs[0]

    def half(name, val):
        """Divide ``val`` by two."""
        node = gs.Node(
            op="Div",
            name=f"div2_{name}",
            inputs=[
                val,
                gs.Constant(
                    name=f"two_{name}",
                    values=np.array([2.0], dtype=np.float32),
                ),
            ],
            outputs=[gs.Variable(name=f"{name}_half", dtype=np.float32)],
        )
        new_nodes.append(node)
        return node.outputs[0]

    def arith(op, name, a, b):
        """Apply the binary elementwise ``op`` to ``a`` and ``b``."""
        node = gs.Node(
            op=op,
            name=name,
            inputs=[a, b],
            outputs=[gs.Variable(name=f"{name}_out", dtype=np.float32)],
        )
        new_nodes.append(node)
        return node.outputs[0]

    objectness = slice_op("slice_obj", 4, 5)
    class_scores = slice_op("slice_classes", 5, to_end)
    x = slice_op("slice_x", 0, 1)
    y = slice_op("slice_y", 1, 2)
    w = slice_op("slice_w", 2, 3)
    h = slice_op("slice_h", 3, 4)

    # Centre/size -> corner coordinates.
    hw = half("w", w)
    hh = half("h", h)
    x1 = arith("Sub", "x1", x, hw)
    y1 = arith("Sub", "y1", y, hh)
    x2 = arith("Add", "x2", x, hw)
    y2 = arith("Add", "y2", y, hh)

    concat_boxes = gs.Node(
        op="Concat",
        name="concat_boxes_xyxy",
        inputs=[x1, y1, x2, y2],
        outputs=[gs.Variable(name="boxes_xyxy", dtype=np.float32)],
        attrs={"axis": 2},
    )
    new_nodes.append(concat_boxes)
    boxes_xyxy = concat_boxes.outputs[0]

    # Final per-class score = objectness * class score; the single
    # objectness column broadcasts across the class axis.
    mul_scores = gs.Node(
        op="Mul",
        name="final_scores",
        inputs=[objectness, class_scores],
        outputs=[gs.Variable(name="scores", dtype=np.float32)],
    )
    new_nodes.append(mul_scores)
    scores = mul_scores.outputs[0]

    # Register the new nodes and expose boxes and scores as the only outputs.
    graph.nodes.extend(new_nodes)
    graph.outputs = [boxes_xyxy, scores]
    graph.cleanup().toposort()
    return gs.export_onnx(graph)
def verify_model_output(onnx_model_path):
    """Run the model once on random data and print each output's shape.

    Every model input is fed the same random float32 tensor of shape
    (4, 3, 640, 640).

    Args:
        onnx_model_path: Path of the ONNX model file to load.
    """
    session_input = np.random.random((4, 3, 640, 640)).astype(np.float32)
    ort_session = onnxruntime.InferenceSession(
        onnx_model_path, providers=["CUDAExecutionProvider"]
    )

    # Key the feed by input name. Iterating the inputs directly (instead of
    # zipping them with the batch rows) keeps every declared input in the feed.
    feed = {
        input_arg.name: session_input
        for input_arg in ort_session.get_inputs()
    }
    outputs = ort_session.run(None, feed)

    # run(None, ...) returns values in the same order as get_outputs().
    for output_arg, output_value in zip(ort_session.get_outputs(), outputs):
        print(f"{output_arg.name}: {output_value.shape}")
def export_model(
    onnx_model: onnx.ModelProto, output_filepath, ir_version: int = 10
):
    """Stamp the model with an IR version and write it to disk.

    Note that ``onnx_model`` is modified in place: its ``ir_version`` field is
    overwritten before saving.

    Args:
        onnx_model: Model to save.
        output_filepath: Destination path for the serialized model.
        ir_version: IR version written into the model. Defaults to 10.
            NOTE(review): presumably pinned for runtime compatibility;
            confirm against the target onnxruntime version.
    """
    onnx_model.ir_version = ir_version
    onnx.save(onnx_model, output_filepath)
if __name__ == "__main__":
    base_path = "models/exported/yolon.onnx"
    nms_path = "models/exported/yolon-with_nms.onnx"

    # Print the output shapes of the untouched export first.
    verify_model_output(base_path)

    # Split the fused prediction into boxes/scores, then attach NMS.
    split_model = add_v5_split(onnx.load(base_path))
    nms_model = add_nms(split_model)

    # Save the augmented model and print its output shapes as well.
    export_model(nms_model, nms_path)
    verify_model_output(nms_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment