Skip to content

Instantly share code, notes, and snippets.

@thsno02
Created October 9, 2023 02:35
Show Gist options
  • Save thsno02/355aec92f6639ef8e6f770e0a7bc7709 to your computer and use it in GitHub Desktop.
Save thsno02/355aec92f6639ef8e6f770e0a7bc7709 to your computer and use it in GitHub Desktop.
add_normalization_to_onnx
def _backup_path(model_path: str) -> str:
    """Return the path where the unmodified copy of ``model_path`` is saved.

    "inference" in the *file name* becomes "raw_inference" (the original
    naming scheme). If the file name does not contain "inference", the name is
    prefixed with "raw_" instead. Directory components are never rewritten, so
    the backup always lands next to the original model.
    """
    import os

    directory, filename = os.path.split(model_path)
    raw_name = filename.replace("inference", "raw_inference")
    if raw_name == filename:
        # Without this fallback the backup path would equal model_path and the
        # untouched copy would be overwritten by the edited model.
        raw_name = "raw_" + filename
    return os.path.join(directory, raw_name)


def _unique_name(base: str, taken: set) -> str:
    """Return ``base`` (or ``base_<k>``) such that it is not in ``taken``; record it."""
    name = base
    suffix = 1
    while name in taken:
        name = f"{base}_{suffix}"
        suffix += 1
    taken.add(name)
    return name


def add_normalization_to_onnx(model_path: str, first_node_name: str, mean: list, std: list):
    """Prepend ``(x - mean) / std`` preprocessing to an exported ONNX model.

    The model at ``model_path`` is loaded and an unmodified copy is saved
    (see ``_backup_path``). A ``Sub`` node and a ``Div`` node are then inserted
    at the front of the graph. Every input slot of the node named
    ``first_node_name`` that read the raw graph input now reads the normalized
    tensor. The edited model overwrites ``model_path``.

    Args:
        model_path: Path of the ONNX model to edit in place.
        first_node_name: ``name`` of the node that consumes the graph input
            and should receive the normalized tensor instead.
        mean: Per-channel mean values (one per channel).
        std: Per-channel standard deviations, same length as ``mean``.

    Raises:
        ValueError: ``mean``/``std`` are empty, differ in length, or ``std``
            contains a zero; no node is named ``first_node_name``; or that
            node does not read the graph input.

    Notes:
        - Assumes ``graph.input[0]`` is the image tensor laid out as
          (batch, channels, height, width). The constants are shaped
          ``(1, C, 1, 1)`` to broadcast over that layout; confirm this for
          your model.
        - ONNX ``Sub``/``Div`` require both operands to share an element
          type. The constants are FLOAT, so the graph input presumably must
          be float32 as well.
        - Only ``first_node_name`` is rewired. Any other node that also reads
          the graph input keeps seeing the raw, unnormalized tensor.
    """
    import onnx
    from onnx import helper

    mean = [float(m) for m in mean]
    std = [float(s) for s in std]
    if not mean or len(mean) != len(std):
        raise ValueError(
            f"mean and std must be non-empty and of equal length, got {len(mean)} and {len(std)}"
        )
    if any(s == 0.0 for s in std):
        raise ValueError("std values must be non-zero")
    channels = len(mean)

    model = onnx.load(model_path)
    graph = model.graph
    input_tensor = graph.input[0]

    # Locate and validate the target node before touching anything on disk,
    # so a typo in first_node_name does not leave a stray backup behind.
    target = next((node for node in graph.node if node.name == first_node_name), None)
    if target is None:
        raise ValueError(f"no node named {first_node_name!r} in {model_path}")
    slots = [i for i, name in enumerate(target.input) if name == input_tensor.name]
    if not slots:
        raise ValueError(
            f"node {first_node_name!r} does not consume graph input {input_tensor.name!r}"
        )

    onnx.save(model, _backup_path(model_path))

    # Collect every tensor name already in use, so the new initializers and
    # intermediate outputs cannot shadow existing ones.
    taken = {v.name for v in graph.input}
    taken.update(v.name for v in graph.output)
    taken.update(t.name for t in graph.initializer)
    taken.update(v.name for v in graph.value_info)
    for node in graph.node:
        taken.update(node.input)
        taken.update(node.output)
    mean_name = _unique_name("mean", taken)
    std_name = _unique_name("std_dev", taken)
    sub_name = _unique_name("sub_output", taken)
    div_name = _unique_name("div_output", taken)

    graph.initializer.extend([
        helper.make_tensor(mean_name, onnx.TensorProto.FLOAT, [1, channels, 1, 1], mean),
        helper.make_tensor(std_name, onnx.TensorProto.FLOAT, [1, channels, 1, 1], std),
    ])

    # Rewire before inserting the new nodes, while `target` still refers to
    # its original position in the repeated field.
    for slot in slots:
        target.input[slot] = div_name

    sub_node = helper.make_node("Sub", inputs=[input_tensor.name, mean_name], outputs=[sub_name])
    div_node = helper.make_node("Div", inputs=[sub_name, std_name], outputs=[div_name])
    # Both nodes depend only on the graph input and initializers, so putting
    # them first keeps the node list topologically sorted.
    graph.node.insert(0, sub_node)
    graph.node.insert(1, div_node)

    onnx.save(model, model_path)
if __name__ == "__main__":
    # Command-line entry point. It replaces the placeholder assignments
    # (`mean = [, , ,]`), which were not valid Python, and keeps the edit from
    # running when this module is imported.
    import argparse

    parser = argparse.ArgumentParser(
        description="Prepend (x - mean) / std normalization to an ONNX inference model."
    )
    parser.add_argument("model_path", help="path of the inference model to edit in place")
    parser.add_argument(
        "first_node_name", help="name of the first node in the inference model"
    )
    parser.add_argument(
        "--mean", type=float, nargs="+", required=True, help="per-channel mean values"
    )
    parser.add_argument(
        "--std", type=float, nargs="+", required=True, help="per-channel std values"
    )
    args = parser.parse_args()
    add_normalization_to_onnx(args.model_path, args.first_node_name, args.mean, args.std)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment