Created
September 1, 2016 17:49
-
-
Save gibiansky/407340dc25348c1d38e21d177b081397 to your computer and use it in GitHub Desktop.
Demo of wrapping TensorFlow ops with some timing info
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import tensorflow as tf | |
def print_time(name):
    """Build a multi-argument identity function that logs a timestamp.

    The returned callable prints ``name`` followed by the current
    ``time.time()`` value, then hands back its positional arguments
    unchanged as a tuple::

        print_time("Name")(1, 2, 3)
        # prints: Name 1472751445.5185795
        # returns: (1, 2, 3)

    Intended for use with tf.py_func so a graph node can report when it
    actually executed.  See py_func here:
    https://www.tensorflow.org/versions/r0.7/api_docs/python/script_ops.html#py_func
    """
    def timestamped_identity(*inputs):
        # Side effect first, then pass everything through untouched.
        print(name, time.time())
        return inputs
    return timestamped_identity
class TimerGraph(tf.Graph):
    """tf.Graph subclass that chains timing printers onto selected ops.

    Each time an op whose type appears in ``_TIMED_OPS`` is added to the
    graph, a tf.py_func identity node is appended after it so the time of
    completion gets printed when that op finishes running.

    This idea can be extended (probably?) to adding a custom timing op to
    record start AND end times to some sort of in-memory database, and
    then printing it out at the end.  Perhaps this would also be the right
    place to add "performance counters", and effectively have a large
    mapping between op type and performance metrics (floating point
    operations, bytes transferred, etc), which then gets populated with
    data for the model in question in this subclass.
    """

    # Op types that receive a timing wrapper; all other ops pass through.
    _TIMED_OPS = ("Pow", "Add")

    def create_op(self, *args, **kwargs):
        """Hook called for every new op; also registers the op in the graph."""
        original_op = super(TimerGraph, self).create_op(*args, **kwargs)
        # First positional argument to create_op is the op type name.
        op_type = args[0]
        if op_type not in self._TIMED_OPS:
            # Not an op we want to time (e.g. Placeholder) -- return it as-is.
            return original_op
        # The wrapper just forwards all of its inputs, so its output dtypes
        # are copied straight from the wrapped op's outputs.
        passthrough_dtypes = [out.dtype for out in original_op.outputs]
        # Insert a python-callback node: a multi-argument identity that
        # records some metadata.  Making it a graph op guarantees the print
        # fires at the right point during execution.
        timed_tensors = tf.py_func(print_time(op_type), original_op.outputs,
                                   passthrough_dtypes)
        # py_func hands back output tensors, but callers of create_op expect
        # an op -- so grab the op behind the first output tensor, which is
        # the py_func op we care about.
        return timed_tensors[0].op
# Demo: build a small graph inside a TimerGraph so every Add and Pow op
# prints its completion time when the session runs.
graph = TimerGraph()
with graph.as_default():
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    # Five Add ops followed by one Pow op -- all of them get wrapped.
    z = x + y + y + y + y + y
    w = z ** 2
    with tf.Session() as session:
        result = session.run(w, feed_dict={
            x: list(range(10000000)),
            y: list(range(10000000)),
        })
        print(result)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment