Created
August 26, 2018 15:10
-
-
Save jaedeokk/926ce36dc162c845063749f770b26e85 to your computer and use it in GitHub Desktop.
Counting FLOPs of a Keras model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Code Snippet for counting FLOPs of a model. | |
Not final version, it will be updated to improve the usability. | |
""" | |
import os.path | |
import tempfile | |
import tensorflow as tf | |
from tensorflow.python.keras import Model, Sequential | |
def count_flops(model):
    """Count the float operations (FLOPs) of a Keras model.

    The graph of the current Keras backend session is frozen (variables are
    converted to constants), re-imported into a fresh graph and profiled
    with the TensorFlow profiler. As a side effect, the profiled graph is
    written as a summary event file into the ``gg`` directory (relative to
    the current working directory).

    # Args.
        model: Model or Sequential, the Keras model to profile. Its graph is
            taken from the session returned by
            `tf.keras.backend.get_session()`.

    # Returns
        The object returned by `tf.profiler.profile` when run with the
        `float_operation` options. The total FLOPs count is expected to be
        available as its `total_float_ops` attribute (see the tf.profiler
        documentation).

    # Raises
        TypeError, if a model is not an instance of Sequential or Model
    """
    if not isinstance(model, (Sequential, Model)):
        raise TypeError(
            'Model is expected to be an instance of Sequential or Model, '
            'but got %s' % type(model))

    # Only the ops required to compute the model outputs are kept when the
    # graph is frozen.
    output_op_names = [_out_tensor.op.name for _out_tensor in model.outputs]
    sess = tf.keras.backend.get_session()
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_op_names)

    # Round-trip the frozen graph through a temporary file; the directory
    # (and the file in it) is removed when the `with` block exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        graph_file = os.path.join(tmpdir, 'graph.pb')
        with tf.gfile.GFile(graph_file, "wb") as f:
            f.write(frozen_graph_def.SerializeToString())
        with tf.gfile.GFile(graph_file, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

    # Profile a fresh graph that contains nothing but the frozen model.
    with tf.Graph().as_default() as new_graph:
        tf.import_graph_def(graph_def, name='')
        tfprof_opts = tf.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.profiler.profile(new_graph, options=tfprof_opts)

        # Dump the profiled graph for inspection, and always release the
        # event-file writer afterwards so it does not leak.
        writer = tf.summary.FileWriter('gg', graph=new_graph)
        try:
            writer.flush()
        finally:
            writer.close()
    return flops
if __name__ == '__main__':
    # Example usage: profile a randomly initialised VGG16 classifier that is
    # fed a single 224x224 image with 3 channels, then print the result.
    input_tensor = tf.keras.Input(batch_shape=(1, 224, 224, 3))
    vgg16_model = tf.keras.applications.vgg16.VGG16(
        include_top=True,
        weights=None,
        input_tensor=input_tensor)
    print(count_flops(vgg16_model))
@maguscl I'm happy to hear that :) And thanks for pointing out this version compatibility issue of this gist. This gist was written on tf-1.6.
But, TF keeps changing too fast and thus I have no plan to update this gist to fit in the current TF version. Please refer this gist as just an old sample :)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi!
Line 31 must be: sess = tf.compat.v1.keras.backend.get_session()
Thxs for sharing the code. Very useful