"""Simple example of various types of Tensorflow profiling. | |
Adapted from: | |
https://towardsdatascience.com/howto-profile-tensorflow-1a49fb18073d | |
https://www.tensorflow.org/api_docs/python/tf/profiler/Profiler | |
""" | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import sys

import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.examples.tutorials.mnist import input_data

FLAGS = None

class TimeLiner(object):

    _timeline_dict = None

    def update_timeline(self, chrome_trace):
        # Convert the chrome trace (a JSON string) to a Python dict
        chrome_trace_dict = json.loads(chrome_trace)
        if self._timeline_dict is None:
            # For the first run, store the full trace
            self._timeline_dict = chrome_trace_dict
        else:
            # For later runs, append only the timed events (those carrying a
            # 'ts' timestamp field), not the process/thread definitions
            for event in chrome_trace_dict['traceEvents']:
                if 'ts' in event:
                    self._timeline_dict['traceEvents'].append(event)

    def save(self, f_name):
        with open(f_name, 'w') as f:
            json.dump(self._timeline_dict, f)

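
# Each session.run() call traced with FULL_TRACE yields a chrome trace for just
# that one step; TimeLiner merges these per-step traces into a single trace so
# that every profiled step can be viewed together.
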
def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, W) + b

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

    timeliner = TimeLiner()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        profiler = tf.profiler.Profiler(sess.graph)
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        option_builder = tf.profiler.ProfileOptionBuilder

        # Train
        n_steps = FLAGS.steps
        for i in range(n_steps):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            run_metadata = tf.RunMetadata()
            sess.run(
                train_step,
                options=options,
                run_metadata=run_metadata,
                feed_dict={
                    x: batch_xs,
                    y_: batch_ys
                })
            # Collect profiling info for each step
            profiler.add_step(i, run_metadata)

            # Generate a timeline for each step and merge it into the
            # combined trace
            fetched_timeline = timeline.Timeline(run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            timeliner.update_timeline(chrome_trace)

        timeliner.save('test-timeline.txt')
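
        # The merged trace above is in the Chrome trace-event JSON format: open
        # chrome://tracing in Chrome/Chromium and load 'test-timeline.txt' to
        # inspect the per-op timeline of the profiled steps.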

        # Test trained model
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        run_metadata = tf.RunMetadata()
        print(
            sess.run(
                accuracy,
                options=options,
                run_metadata=run_metadata,
                feed_dict={
                    x: mnist.test.images,
                    y_: mnist.test.labels
                }))
        # Collect profiling info for the last step
        profiler.add_step(n_steps, run_metadata)

        # Profile the timing of your model operations.
        # The per-op profiling report is saved to 'test-%s.txt' % FLAGS.out
        opts = (option_builder(option_builder.time_and_memory())
                .with_step(-1)  # with -1, average over all registered steps
                .with_file_output('test-%s.txt' % FLAGS.out)
                .select(['micros', 'bytes', 'occurrence'])
                .order_by('micros')
                .build())
        profiler.profile_operations(options=opts)
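
        # Rough guide to the selected columns in that report: 'micros' is op
        # execution time, 'bytes' is the memory attributed to the op, and
        # 'occurrence' is how many times the op ran; rows are sorted by 'micros'.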

        # Profile the parameters
        profiler.profile_name_scope(
            options=(option_builder.trainable_variables_parameter()))
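        # For this model the name-scope report should show 784 * 10 weights
        # plus 10 biases = 7,850 trainable parameters.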

        # Auto detect problems and generate advice.
        profiler.advise(
            options={
                'ExpensiveOperationChecker': {},
                'AcceleratorUtilizationChecker': {},
                'OperationChecker': {},
            })
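        # As their names suggest, these checkers flag the most expensive
        # operations, low accelerator (GPU) utilization, and common operation
        # issues; their advice is printed to stdout rather than to a file.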


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_dir',
        type=str,
        default='/tmp/tensorflow/mnist/input_data',
        help='Directory for storing input data')
    parser.add_argument(
        '--steps', type=int, default=10, help='Number of steps to run.')
    parser.add_argument(
        '--out',
        type=str,
        default='profiling',
        help="Suffix for the profiling report file ('test-<out>.txt').")
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
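
# Running this script writes two files to the working directory: the merged
# chrome trace ('test-timeline.txt') and the per-op time/memory report
# ('test-profiling.txt' with the default --out value). For example:
#
#   python profiling_example.py --steps 20 --out profiling
#
# where 'profiling_example.py' stands in for whatever this file is saved as.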