Forked from nikashitsa/how_to_remove_dropout_from_frozen_model.ipynb
Created September 5, 2018 08:37
Save TKNgu/09bfb982b7516467aaed48585f4acc2e to your computer and use it in GitHub Desktop.
How to remove dropout from frozen model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
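{
 "metadata": {},
 "cell_type": "markdown",
 "source": "How to remove dropout from a frozen model\n===\n\nA frozen graph exported from a model trained with dropout still contains the dropout sub-graph (the `dropout/*` nodes) and a `keep_prob` placeholder that has to be fed `1.0` at inference time. This notebook deletes those nodes, connects the layer that consumed the dropout output directly to the dropout input, saves the result as a new frozen graph, and compares MNIST test accuracy before and after."
},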
{
 "metadata": {
  "collapsed": false,
  "trusted": true
 },
 "cell_type": "code",
 "source": [
  "from __future__ import print_function\n",
  "from tensorflow.core.framework import graph_pb2\n",
  "import tensorflow as tf\n",
  "import numpy as np\n",
  "from tensorflow.examples.tutorials.mnist import input_data\n",
  "\n",
  "mnist = input_data.read_data_sets('/tmp/data/', one_hot=True)\n",
  "\n",
  "\n",
  "def display_nodes(nodes):\n",
  "    \"\"\"Print each node as '<index> <name> <op>', then one line per input.\"\"\"\n",
  "    for node_index, node in enumerate(nodes):\n",
  "        print('%d %s %s' % (node_index, node.name, node.op))\n",
  "        # Plain loop rather than a list comprehension evaluated only for its\n",
  "        # print side effect (which, under Python 2, would also leak its loop\n",
  "        # variable into this scope).\n",
  "        for input_index, input_name in enumerate(node.input):\n",
  "            print(u'└─── %d ─ %s' % (input_index, input_name))\n",
  "\n",
  "\n",
  "def accuracy(predictions, labels):\n",
  "    \"\"\"Return the percentage (0-100) of rows whose argmax matches the argmax of labels.\"\"\"\n",
  "    correct = np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))\n",
  "    return 100.0 * correct / predictions.shape[0]\n",
  "\n",
  "\n",
  "def test_graph(graph_path, use_dropout, num_samples=256):\n",
  "    \"\"\"Load a frozen GraphDef and return its accuracy on the first MNIST test images.\n",
  "\n",
  "    graph_path: path to a binary-serialized GraphDef with an 'input'\n",
  "        placeholder and a 'final_result' output node.\n",
  "    use_dropout: True if the graph still has the 'keep_prob' placeholder;\n",
  "        it is then fed 1.0.\n",
  "    num_samples: number of leading test images/labels to evaluate.\n",
  "    \"\"\"\n",
  "    graph_def = tf.GraphDef()\n",
  "    with tf.gfile.FastGFile(graph_path, 'rb') as f:\n",
  "        graph_def.ParseFromString(f.read())\n",
  "\n",
  "    # Import into a private graph (not the global default one) and close the\n",
  "    # session when done so repeated calls do not accumulate open sessions.\n",
  "    with tf.Graph().as_default() as graph:\n",
  "        tf.import_graph_def(graph_def, name='')\n",
  "        prediction_tensor = graph.get_tensor_by_name('final_result:0')\n",
  "        feed_dict = {'input:0': mnist.test.images[:num_samples]}\n",
  "        if use_dropout:\n",
  "            feed_dict['keep_prob:0'] = 1.0\n",
  "        with tf.Session(graph=graph) as sess:\n",
  "            predictions = sess.run(prediction_tensor, feed_dict)\n",
  "\n",
  "    return accuracy(predictions, mnist.test.labels[:num_samples])"
 ],
 "execution_count": 1,
 "outputs": [
  {
   "output_type": "stream",
   "name": "stdout",
   "text": "Extracting /tmp/data/train-images-idx3-ubyte.gz\nExtracting /tmp/data/train-labels-idx1-ubyte.gz\nExtracting /tmp/data/t10k-images-idx3-ubyte.gz\nExtracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
  }
 ]
},
{
 "metadata": {
  "collapsed": false,
  "trusted": true
 },
 "cell_type": "code",
 "source": [
  "# Read the frozen graph and display its nodes.\n",
  "# Open with 'rb': the file holds serialized protobuf bytes, and text mode\n",
  "# ('r') yields decoded str rather than bytes under Python 3.\n",
  "graph = tf.GraphDef()\n",
  "with tf.gfile.Open('./frozen_model.pb', 'rb') as f:\n",
  "    graph.ParseFromString(f.read())\n",
  "\n",
  "display_nodes(graph.node)"
 ],
 "execution_count": 2,
 "outputs": [
  {
   "output_type": "stream",
   "name": "stdout",
   "text": "0 input Placeholder\n1 keep_prob Placeholder\n2 Variable Const\n3 Variable/read Identity\n└─── 0 ─ Variable\n4 Variable_1 Const\n5 Variable_1/read Identity\n└─── 0 ─ Variable_1\n6 Variable_2 Const\n7 Variable_2/read Identity\n└─── 0 ─ Variable_2\n8 Variable_3 Const\n9 Variable_3/read Identity\n└─── 0 ─ Variable_3\n10 Variable_4 Const\n11 Variable_4/read Identity\n└─── 0 ─ Variable_4\n12 Variable_5 Const\n13 Variable_5/read Identity\n└─── 0 ─ Variable_5\n14 Variable_6 Const\n15 Variable_6/read Identity\n└─── 0 ─ Variable_6\n16 Variable_7 Const\n17 Variable_7/read Identity\n└─── 0 ─ Variable_7\n18 Reshape/shape Const\n19 Reshape Reshape\n└─── 0 ─ input\n└─── 1 ─ Reshape/shape\n20 Conv2D Conv2D\n└─── 0 ─ Reshape\n└─── 1 ─ Variable/read\n21 BiasAdd BiasAdd\n└─── 0 ─ Conv2D\n└─── 1 ─ Variable_4/read\n22 Relu Relu\n└─── 0 ─ BiasAdd\n23 MaxPool MaxPool\n└─── 0 ─ Relu\n24 Conv2D_1 Conv2D\n└─── 0 ─ MaxPool\n└─── 1 ─ Variable_1/read\n25 BiasAdd_1 BiasAdd\n└─── 0 ─ Conv2D_1\n└─── 1 ─ Variable_5/read\n26 Relu_1 Relu\n└─── 0 ─ BiasAdd_1\n27 MaxPool_1 MaxPool\n└─── 0 ─ Relu_1\n28 Reshape_1/shape Const\n29 Reshape_1 Reshape\n└─── 0 ─ MaxPool_1\n└─── 1 ─ Reshape_1/shape\n30 MatMul MatMul\n└─── 0 ─ Reshape_1\n└─── 1 ─ Variable_2/read\n31 Add Add\n└─── 0 ─ MatMul\n└─── 1 ─ Variable_6/read\n32 Relu_2 Relu\n└─── 0 ─ Add\n33 dropout/Shape Shape\n└─── 0 ─ Relu_2\n34 dropout/random_uniform/min Const\n35 dropout/random_uniform/max Const\n36 dropout/random_uniform/RandomUniform RandomUniform\n└─── 0 ─ dropout/Shape\n37 dropout/random_uniform/sub Sub\n└─── 0 ─ dropout/random_uniform/max\n└─── 1 ─ dropout/random_uniform/min\n38 dropout/random_uniform/mul Mul\n└─── 0 ─ dropout/random_uniform/RandomUniform\n└─── 1 ─ dropout/random_uniform/sub\n39 dropout/random_uniform Add\n└─── 0 ─ dropout/random_uniform/mul\n└─── 1 ─ dropout/random_uniform/min\n40 dropout/add Add\n└─── 0 ─ keep_prob\n└─── 1 ─ dropout/random_uniform\n41 dropout/Floor Floor\n└─── 0 ─ dropout/add\n42 dropout/Div Div\n└─── 0 ─ Relu_2\n└─── 1 ─ keep_prob\n43 dropout/mul Mul\n└─── 0 ─ dropout/Div\n└─── 1 ─ dropout/Floor\n44 MatMul_1 MatMul\n└─── 0 ─ dropout/mul\n└─── 1 ─ Variable_3/read\n45 Add_1 Add\n└─── 0 ─ MatMul_1\n└─── 1 ─ Variable_7/read\n46 final_result Softmax\n└─── 0 ─ Add_1\n"
  }
 ]
},
{
 "metadata": {
  "collapsed": true,
  "trusted": true
 },
 "cell_type": "code",
 "source": [
  "# Remove the dropout sub-graph: drop every 'dropout/*' node and the\n",
  "# 'keep_prob' placeholder, and rewire consumers of 'dropout/mul' (the\n",
  "# dropout output) to the tensor that dropout was applied to.\n",
  "# Nodes are located by name rather than by position in graph.node, and\n",
  "# `graph` itself is left unmodified.\n",
  "DROPOUT_SCOPE = 'dropout/'\n",
  "KEEP_PROB_NAME = 'keep_prob'\n",
  "\n",
  "\n",
  "def is_dropout_node(node):\n",
  "    \"\"\"True for the keep_prob placeholder and every node under DROPOUT_SCOPE.\"\"\"\n",
  "    return node.name == KEEP_PROB_NAME or node.name.startswith(DROPOUT_SCOPE)\n",
  "\n",
  "\n",
  "# The first input of dropout/Div is the layer output being dropped out.\n",
  "dropout_input = next(node.input[0] for node in graph.node\n",
  "                     if node.name == DROPOUT_SCOPE + 'Div')\n",
  "dropout_output = DROPOUT_SCOPE + 'mul'\n",
  "\n",
  "output_graph = graph_pb2.GraphDef()\n",
  "output_graph.CopyFrom(graph)  # keep versions/library, not only the nodes\n",
  "output_graph.ClearField('node')\n",
  "for node in graph.node:\n",
  "    if is_dropout_node(node):\n",
  "        continue\n",
  "    new_node = output_graph.node.add()\n",
  "    new_node.CopyFrom(node)\n",
  "    for i, input_name in enumerate(new_node.input):\n",
  "        if input_name == dropout_output:\n",
  "            new_node.input[i] = dropout_input\n",
  "\n",
  "# Every remaining input (ignoring a '^' control prefix and ':N' output\n",
  "# suffix) must refer to a node that is still in the graph.\n",
  "kept_names = set(node.name for node in output_graph.node)\n",
  "for node in output_graph.node:\n",
  "    for input_name in node.input:\n",
  "        source_name = input_name.lstrip('^').split(':')[0]\n",
  "        assert source_name in kept_names, (node.name, input_name)\n",
  "\n",
  "# Save graph in binary mode: SerializeToString returns bytes.\n",
  "with tf.gfile.GFile('./frozen_model_without_dropout.pb', 'wb') as f:\n",
  "    f.write(output_graph.SerializeToString())"
 ],
 "execution_count": 3,
 "outputs": []
},
{
 "metadata": {
  "collapsed": false,
  "trusted": true
 },
 "cell_type": "code",
 "source": [
  "# Compare the original graph (keep_prob fed as 1.0, which should make\n",
  "# dropout an identity) with the graph whose dropout nodes were removed.\n",
  "accuracy_with_dropout = test_graph('./frozen_model.pb', use_dropout=True)\n",
  "accuracy_without_dropout = test_graph('./frozen_model_without_dropout.pb', use_dropout=False)\n",
  "\n",
  "print('with dropout: %f' % accuracy_with_dropout)\n",
  "print('without dropout: %f' % accuracy_without_dropout)\n",
  "assert accuracy_with_dropout == accuracy_without_dropout, 'removing dropout changed the accuracy'"
 ],
 "execution_count": 4,
 "outputs": [
  {
   "output_type": "stream",
   "name": "stdout",
   "text": "with dropout: 80.859375\nwithout dropout: 80.859375\n"
  }
 ]
},
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "python2", | |
"display_name": "Python 2", | |
"language": "python" | |
}, | |
"language_info": { | |
"mimetype": "text/x-python", | |
"nbconvert_exporter": "python", | |
"name": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.12", | |
"file_extension": ".py", | |
"codemirror_mode": { | |
"version": 2, | |
"name": "ipython" | |
} | |
}, | |
"gist": { | |
"id": "4498bb2174d85104c4396d3f48a0a09d", | |
"data": { | |
"description": "How to remove dropout from frozen model", | |
"public": true | |
} | |
}, | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/4498bb2174d85104c4396d3f48a0a09d" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.