Skip to content

Instantly share code, notes, and snippets.

@rizar
Created November 30, 2016 15:12
Show Gist options
  • Save rizar/f4556741c2c79c4adf4f47c6ed1cefd7 to your computer and use it in GitHub Desktop.
Huge overhead of TensorFlow profiling
import tensorflow as tf
import numpy
import time
# Computation graph: a single-layer LSTM language model.
batch_size = 100
vocab_size = 50000
dim = 512

# Token ids, one row per sequence.  The first T-1 positions are fed to the
# network; positions 1..T-1 serve as prediction targets.
inputs = tf.placeholder(tf.int32, [batch_size, None], name='inputs')
input_lengths = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
num_steps = tf.shape(inputs)[1] - 1

embedding_table = tf.get_variable("embeddings", [vocab_size, dim])
embedded = tf.nn.embedding_lookup(embedding_table, inputs[:, :-1])
targets = inputs[:, 1:]

lstm = tf.nn.rnn_cell.BasicLSTMCell(dim, forget_bias=0.0)
hidden_states, final_state = tf.nn.dynamic_rnn(lstm, embedded,
                                               dtype=tf.float32)

# Project every hidden state onto the vocabulary.
proj_w = tf.get_variable("softmax_w", [dim, vocab_size])
proj_b = tf.get_variable("softmax_b", [vocab_size])
flat_states = tf.reshape(hidden_states, (-1, dim))
logits = tf.nn.log_softmax(tf.matmul(flat_states, proj_w) + proj_b)

# Per-position cross-entropy, reshaped back to (batch, time) and masked so
# that positions past each sequence's length contribute nothing.
xent = tf.nn.sparse_softmax_cross_entropy_with_logits
flat_losses = xent(logits, tf.reshape(targets, (-1,)))
step_losses = tf.reshape(flat_losses, tf.shape(targets))
mask = tf.to_float(tf.less(tf.range(num_steps)[None, :],
                           input_lengths[:, None] - 1))
step_losses *= mask

# Gradients of the total masked loss w.r.t. all trainable variables —
# this is the op the benchmark below repeatedly runs.
losses_grads = tf.gradients(tf.reduce_sum(step_losses),
                            tf.trainable_variables())
sess = tf.Session()
# NOTE(review): initialize_all_variables() is the era-appropriate call here
# (later TF versions deprecate it in favor of global_variables_initializer).
sess.run(tf.initialize_all_variables())

# Feed 100-step all-ones batches.  dtype is given explicitly so the arrays
# match the int32 placeholders instead of numpy's float64 default.
feed_dict = {inputs: numpy.ones((batch_size, 100), dtype=numpy.int32),
             input_lengths: 100 * numpy.ones((batch_size,),
                                             dtype=numpy.int32)}

# Time 10 forward+backward runs without profiling.
for i in range(10):
    before = time.time()
    sess.run([losses_grads], feed_dict=feed_dict)
    print(time.time() - before)

# Same 10 runs with full tracing enabled, to expose the profiling overhead.
run_metadata = tf.RunMetadata()
for i in range(10):
    before = time.time()
    sess.run([losses_grads],
             feed_dict=feed_dict,
             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
             run_metadata=run_metadata)
    print(time.time() - before)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment