@ibaiGorordo
Created April 2, 2022 07:38
Python script that builds an ONNX model which concatenates two inputs along the channel axis, using ONNX GraphSurgeon.
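The script below has to declare the output shape by hand, so it helps to spell out the rule. ONNX Concat requires all inputs to match on every dimension except the concatenation axis, and the sizes along that axis are added together. The sketch below is a minimal pure-Python version of that rule. `concat_output_shape` is a hypothetical helper for illustration and is not part of `onnx` or `onnx_graphsurgeon`. It shows why two (1, 3, 240, 320) NCHW inputs concatenated on axis 1 give (1, 6, 240, 320).

```python
def concat_output_shape(shapes, axis):
    """Hypothetical helper: output shape of an ONNX-style Concat.

    All inputs must have the same rank and identical sizes on every
    dimension except `axis`; the sizes along `axis` are summed.
    """
    rank = len(shapes[0])
    if axis < 0:
        axis += rank  # ONNX Concat also accepts negative axes counted from the end
    if not 0 <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    for shape in shapes:
        if len(shape) != rank:
            raise ValueError("all inputs must have the same rank")
        for dim in range(rank):
            if dim != axis and shape[dim] != shapes[0][dim]:
                raise ValueError(f"dimension {dim} mismatch: {shape[dim]} vs {shapes[0][dim]}")
    out = list(shapes[0])
    out[axis] = sum(shape[axis] for shape in shapes)
    return tuple(out)


# Two 3-channel NCHW images joined on the channel axis (axis=1).
print(concat_output_shape([(1, 3, 240, 320), (1, 3, 240, 320)], axis=1))  # (1, 6, 240, 320)
```

In NCHW layout, axis 1 and axis -3 refer to the same channel dimension, so either value gives the same result here.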
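import onnx_graphsurgeon as gs
import numpy as np
import onnx

# Graph inputs: two rectified RGB images in NCHW layout (batch 1, 3 channels, 240x320).
left_rect_img = gs.Variable(name="left_rect", dtype=np.float32, shape=(1, 3, 240, 320))
right_rect_img = gs.Variable(name="right_rect", dtype=np.float32, shape=(1, 3, 240, 320))

# Graph output: both images stacked along the channel axis (3 + 3 = 6 channels).
concat_img = gs.Variable(name="concat_img", dtype=np.float32, shape=(1, 6, 240, 320))

# A single Concat node; axis=1 is the channel axis in NCHW layout.
concat_node = gs.Node(op="Concat", attrs={"axis": 1}, inputs=[left_rect_img, right_rect_img], outputs=[concat_img])

# Wrap the node in a graph, export it to an ONNX ModelProto, and save it to disk.
graph = gs.Graph(nodes=[concat_node], inputs=[left_rect_img, right_rect_img], outputs=[concat_img])
onnx.save(gs.export_onnx(graph), "test_concat.onnx")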
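For reference, the single Concat node with `axis=1` should compute the same thing as `np.concatenate(..., axis=1)`. The NumPy-only sketch below shows the expected output for the exported model. The random arrays are stand-ins for real rectified images and are an assumption for illustration only. Channels 0 to 2 of the result hold the left image, and channels 3 to 5 hold the right image.

```python
import numpy as np

# Random stand-ins for the two rectified images (NCHW, batch 1, 3 channels, 240x320).
rng = np.random.default_rng(0)
left = rng.random((1, 3, 240, 320), dtype=np.float32)
right = rng.random((1, 3, 240, 320), dtype=np.float32)

# NumPy equivalent of an ONNX Concat node with axis=1 (the channel axis).
concat = np.concatenate([left, right], axis=1)

print(concat.shape, concat.dtype)  # (1, 6, 240, 320) float32

# The left image fills channels 0-2 and the right image fills channels 3-5.
assert np.array_equal(concat[:, :3], left)
assert np.array_equal(concat[:, 3:], right)
```

If onnxruntime is installed, you can compare this against the saved file. `onnxruntime.InferenceSession("test_concat.onnx").run(None, {"left_rect": left, "right_rect": right})[0]` should return an equal array, because the feed names match the `gs.Variable` names in the script. That comparison has not been run here.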