Skip to content

Instantly share code, notes, and snippets.

@annarailton
Created July 9, 2018 09:13
Show Gist options
  • Save annarailton/f48cac170c9e80edc26a6cccc2ecb3ac to your computer and use it in GitHub Desktop.
Save annarailton/f48cac170c9e80edc26a6cccc2ecb3ac to your computer and use it in GitHub Desktop.
"""Simple example of various types of Tensorflow profiling.
Adapted from:
https://towardsdatascience.com/howto-profile-tensorflow-1a49fb18073d
https://www.tensorflow.org/api_docs/python/tf/profiler/Profiler
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import json
import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.examples.tutorials.mnist import input_data
# Parsed command-line flags. Stays None on import; the `__main__` block
# replaces it with the argparse namespace before main() reads
# FLAGS.data_dir, FLAGS.steps and FLAGS.out.
FLAGS = None
class TimeLiner(object):
    """Merge the Chrome traces of several runs into one timeline.

    The first trace handed to ``update_timeline`` is kept in full. From
    every later trace only the events that carry a ``'ts'`` key are
    appended, so the merged result can be written out as a single JSON
    file with ``save``.
    """

    # Merged trace as a plain dict; remains None until the first update.
    _timeline_dict = None

    def update_timeline(self, chrome_trace):
        """Fold one Chrome trace into the accumulated timeline.

        Args:
            chrome_trace: JSON string encoding a dict that has a
                ``'traceEvents'`` list.

        Raises:
            ValueError: if ``chrome_trace`` is not valid JSON.
            KeyError: if a trace lacks the ``'traceEvents'`` key when
                events need to be merged.
        """
        # Convert the Chrome trace to a Python dict.
        chrome_trace_dict = json.loads(chrome_trace)

        # First run: store the full trace.
        if self._timeline_dict is None:
            self._timeline_dict = chrome_trace_dict
            return

        # Later runs: add only the events that carry a 'ts' key (the time
        # consumption entries); events without it are not repeated.
        self._timeline_dict['traceEvents'].extend(
            event for event in chrome_trace_dict['traceEvents']
            if 'ts' in event)

    def save(self, f_name):
        """Write the merged timeline to ``f_name`` as JSON.

        When no trace has been recorded yet, an empty trace object
        (``{"traceEvents": []}``) is written rather than a bare ``null``,
        so the output always has the same top-level shape.

        Args:
            f_name: path of the file to create or overwrite.
        """
        timeline_dict = self._timeline_dict
        if timeline_dict is None:
            timeline_dict = {'traceEvents': []}
        with open(f_name, 'w') as f:
            json.dump(timeline_dict, f)
def main(_):
    """Train a linear softmax classifier on MNIST and profile every step.

    Each training step runs with full tracing. Its run metadata is
    registered with a ``tf.profiler.Profiler`` and its Chrome trace is
    merged into a ``TimeLiner``, which is saved to ``'test-timeline.txt'``.
    After training, one traced accuracy evaluation on the test set is
    printed and registered as an extra step. Finally the profiler reports
    per-operation timing and memory (written to
    ``'test-%s.txt' % FLAGS.out``), the trainable parameters, and its
    automatic advice.

    Args:
        _: unused; positional argument supplied by ``tf.app.run``.
    """
    # Import data.
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Create the model: one linear layer from 784 inputs to 10 outputs.
    images = tf.placeholder(tf.float32, [None, 784])
    weights = tf.Variable(tf.zeros([784, 10]))
    bias = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(images, weights) + bias

    # Define loss and optimizer.
    labels = tf.placeholder(tf.float32, [None, 10])
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    cross_entropy = tf.reduce_mean(per_example_loss)
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train_step = optimizer.minimize(cross_entropy)

    timeliner = TimeLiner()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        profiler = tf.profiler.Profiler(sess.graph)
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        option_builder = tf.profiler.ProfileOptionBuilder

        # Train.
        n_steps = FLAGS.steps
        for step in range(n_steps):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            train_feed = {images: batch_xs, labels: batch_ys}
            run_metadata = tf.RunMetadata()
            sess.run(
                train_step,
                feed_dict=train_feed,
                options=options,
                run_metadata=run_metadata)

            # Collect profiling info for this step.
            profiler.add_step(step, run_metadata)

            # Generate a timeline for this step and merge it in.
            step_timeline = timeline.Timeline(run_metadata.step_stats)
            timeliner.update_timeline(
                step_timeline.generate_chrome_trace_format())

        timeliner.save('test-timeline.txt')

        # Test the trained model.
        correct_prediction = tf.equal(
            tf.argmax(logits, 1), tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        test_feed = {images: mnist.test.images, labels: mnist.test.labels}
        run_metadata = tf.RunMetadata()
        accuracy_value = sess.run(
            accuracy,
            feed_dict=test_feed,
            options=options,
            run_metadata=run_metadata)
        print(accuracy_value)

        # Collect profiling info for the evaluation as the last step.
        profiler.add_step(n_steps, run_metadata)

        # Profile the timing of the model operations. The report is saved
        # in 'test-%s.txt' % FLAGS.out. With step -1 the profiler should
        # compute the average of all registered steps.
        builder = option_builder(option_builder.time_and_memory())
        builder = builder.with_step(-1)
        builder = builder.with_file_output('test-%s.txt' % FLAGS.out)
        builder = builder.select(['micros', 'bytes', 'occurrence'])
        builder = builder.order_by('micros')
        opts = builder.build()
        profiler.profile_operations(options=opts)

        # Profile the parameters.
        parameter_opts = option_builder.trainable_variables_parameter()
        profiler.profile_name_scope(options=parameter_opts)

        # Auto-detect problems and generate advice.
        advice_checkers = {
            'ExpensiveOperationChecker': {},
            'AcceleratorUtilizationChecker': {},
            'OperationChecker': {},
        }
        profiler.advise(options=advice_checkers)
if __name__ == '__main__':
    # Command-line flags, listed in the order they are registered.
    flag_specs = (
        ('--data_dir', {
            'type': str,
            'default': '/tmp/tensorflow/mnist/input_data',
            'help': 'Directory for storing input data',
        }),
        ('--steps', {
            'type': int,
            'default': 10,
            'help': 'Number of steps to run.',
        }),
        ('--out', {
            'type': str,
            'default': 'profiling',
            'help': 'Output filename.',
        }),
    )

    parser = argparse.ArgumentParser()
    for flag_name, flag_kwargs in flag_specs:
        parser.add_argument(flag_name, **flag_kwargs)

    # Arguments argparse does not recognise are passed on to tf.app.run
    # together with the program name.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment