Skip to content

Instantly share code, notes, and snippets.

@DibyaranjanSathua
Created February 28, 2019 11:42
Show Gist options
  • Save DibyaranjanSathua/383b4cbf056aa9e899a2d85cb3a24f54 to your computer and use it in GitHub Desktop.
Save DibyaranjanSathua/383b4cbf056aa9e899a2d85cb3a24f54 to your computer and use it in GitHub Desktop.
Remove the dropout blocks from a TensorFlow GraphDef.
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
import copy
def remove_dropout(input_graph, output_graph,
                   dropout_prefixes=("dropout1", "dropout2"),
                   rewired_inputs=None):
    """Remove the dropout blocks from a graph and save the pruned graph.

    Every node whose name starts with one of ``dropout_prefixes`` is dropped.
    Selected inputs of the surviving nodes are then overwritten according to
    ``rewired_inputs``. This reconnects the consumers of the removed dropout
    outputs.

    Args:
        input_graph: Value passed to ``load_graph`` (defined elsewhere) to
            obtain the source ``GraphDef``.
        output_graph: Path the pruned ``GraphDef`` is written to, as a
            binary-serialized protobuf.
        dropout_prefixes: Node-name prefix (str) or iterable of prefixes that
            identify the nodes to remove. The default matches the two dropout
            blocks of the original model.
        rewired_inputs: Mapping ``{node_name: {input_index: new_input}}`` of
            inputs to overwrite on kept nodes. ``None`` selects the re-wiring
            of the original model: the fc2 and logits batch-norm ``mul_1``
            nodes.

    Returns:
        The pruned ``GraphDef`` that was written to ``output_graph``.
    """
    if rewired_inputs is None:
        rewired_inputs = {
            "fc2/fc2/batch_norm/batchnorm/mul_1": {
                0: "mul",
                1: "fc2/weights",
            },
            "logits/logits/batch_norm/batchnorm/mul_1": {
                0: "fc2/activation",
                1: "logits/weights",
            },
        }
    # str.startswith accepts a tuple; a bare string must not be split into
    # its individual characters.
    if isinstance(dropout_prefixes, str):
        prefixes = (dropout_prefixes,)
    else:
        prefixes = tuple(dropout_prefixes)

    graph_def = load_graph(input_graph)
    new_graph_def = graph_pb2.GraphDef()
    # Carry over every GraphDef field (versions, function library, ...), not
    # only the nodes. Then rebuild the node list without the dropout nodes.
    new_graph_def.CopyFrom(graph_def)
    del new_graph_def.node[:]

    for node in graph_def.node:
        if node.name.startswith(prefixes):
            continue
        # add() + CopyFrom copies each kept node exactly once.
        new_node = new_graph_def.node.add()
        new_node.CopyFrom(node)
        for index, new_input in rewired_inputs.get(node.name, {}).items():
            new_node.input[index] = new_input

    # tf.gfile is the TF 1.x module name; TF 2.x only provides tf.io.gfile.
    gfile = getattr(tf, "gfile", None) or tf.io.gfile
    with gfile.GFile(output_graph, "wb") as f:
        f.write(new_graph_def.SerializeToString())
    return new_graph_def
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment