@DibyaranjanSathua
Created February 28, 2019 10:42
Optimize the batch normalization block by removing the training nodes.
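During training, TensorFlow's batch-norm layers wrap the statistics in a tf.cond that switches between the current batch's moments and the moving averages. At inference only the moving averages are used, so those cond/moments sub-graphs can be dropped and the remaining nodes rewired straight to the frozen moving_mean and moving_variance, leaving just the standard normalization:

    y = gamma * (x - moving_mean) / sqrt(moving_variance + epsilon) + beta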
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
import copy
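# load_graph is referenced below but defined earlier in the gist; a minimal
# sketch, assuming input_graph is the path to a frozen GraphDef protobuf:
def load_graph(input_graph):
    """ Read a serialized GraphDef from disk. """
    graph_def = graph_pb2.GraphDef()
    with tf.gfile.GFile(input_graph, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def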
def optimize_batch_normalization(input_graph, output_graph):
    """ Optimize the batch normalization block. """
    graph_def = load_graph(input_graph)  # Defined above
    new_graph_def = graph_pb2.GraphDef()
    unused_attrs = ['is_training']  # Attributes of FusedBatchNorm that are not needed during inference.
    # All the node names are specific to my OCR model.
    # All the input names were found manually from TensorBoard.
    for node in graph_def.node:
        modified_node = copy.deepcopy(node)
        if node.name.startswith("conv"):  # True for convolutional layers
            # Work out which conv block this node belongs to.
            starting_name = ""
            if node.name.startswith("conv1"):
                starting_name = "conv1"
            elif node.name.startswith("conv2"):
                starting_name = "conv2"
            elif node.name.startswith("conv3"):
                starting_name = "conv3"
            elif node.name.startswith("conv4"):
                starting_name = "conv4"
            # Do not add the cond block and its child nodes.
            # They are only needed during training.
            if "cond" in node.name and not node.name.endswith("FusedBatchNorm"):
                continue
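            # FusedBatchNorm takes five inputs, in order: x (here the Conv2D
            # output), scale (gamma), offset (beta), mean and variance. For
            # inference the mean/variance inputs must point at the frozen
            # moving averages rather than the training-time batch statistics.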
if node.op == "FusedBatchNorm" and node.name.endswith("FusedBatchNorm"):
if bool(starting_name):
# Changing the name to remove one block hierarchy and changing inputs.
modified_node.name = "{0}/{0}/batch_norm/FusedBatchNorm".format(starting_name)
modified_node.input[0] = "{}/Conv2D".format(starting_name)
modified_node.input[1] = "{}/batch_norm/gamma".format(starting_name)
modified_node.input[2] = "{}/batch_norm/beta".format(starting_name)
modified_node.input[3] = "{}/batch_norm/moving_mean".format(starting_name)
modified_node.input[4] = "{}/batch_norm/moving_variance".format(starting_name)
# Deleting unused attributes
for attr in unused_attrs:
if attr in modified_node.attr:
del modified_node.attr[attr]
if node.name.endswith('activation'):
if bool(starting_name):
modified_node.input[0] = "{0}/{0}/batch_norm/FusedBatchNorm".format(starting_name)
elif node.name.startswith("fc") or node.name.startswith("logits"): # True for fully connected layers
starting_name = ""
if node.name.startswith("fc1"):
starting_name = "fc1"
elif node.name.startswith("fc2"):
starting_name = "fc2"
elif node.name.startswith("logits"):
starting_name = "logits"
# Do not add cond, cond_1 and moments block of batch normalization
if "cond" in node.name or "moments" in node.name:
continue
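            # In the non-fused batch-norm sub-graph these layers use,
            # batchnorm/add computes variance + epsilon (add/y holds the
            # epsilon constant) and batchnorm/mul_2 computes
            # mean * (gamma * rsqrt(variance + epsilon)). Both are rewired
            # here to read the frozen moving statistics.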
            # Change the inputs of batchnorm/add.
            if node.name.endswith('batchnorm/add'):
                modified_node.input[0] = "{}/batch_norm/moving_variance".format(starting_name)
                modified_node.input[1] = "{0}/{0}/batch_norm/batchnorm/add/y".format(starting_name)
            # Change the inputs of batchnorm/mul_2.
            if node.name.endswith('batchnorm/mul_2'):
                modified_node.input[0] = "{0}/{0}/batch_norm/batchnorm/mul".format(starting_name)
                modified_node.input[1] = "{}/batch_norm/moving_mean".format(starting_name)
        new_graph_def.node.extend([modified_node])
    # Save the optimized graph.
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(new_graph_def.SerializeToString())
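A hypothetical invocation (the file names are made up) plus a quick sanity check that the rewired graph imports without missing-input errors:

optimize_batch_normalization("frozen_ocr_graph.pb", "optimized_ocr_graph.pb")

optimized_graph_def = load_graph("optimized_ocr_graph.pb")
tf.import_graph_def(optimized_graph_def, name="")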