@razhangwei
Forked from vihari/tf_print.py
Last active December 12, 2018 12:47
Redirect TensorFlow's tf.Print output to stdout instead of the default stderr #TensorFlow
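"""
The default tf.Print op writes to STDERR.
Use the function below to direct the output to STDOUT instead.

Usage (TensorFlow 1.x graph mode; p.eval() needs a default session):
> sess = tf.InteractiveSession()
> x = tf.ones([1, 2])
> y = tf.zeros([1, 3])
> p = x * x
> p = tf_print(p, [x, y], "hello")
> p.eval()
hello [[1. 1.]]
hello [[0. 0. 0.]]

The two lines may appear in either order: the py_func ops do not
depend on each other.
"""
import sys

import tensorflow as tf


def tf_print(op, tensors, message=None):
    """Return an identity of `op` that prints `tensors` to stdout when evaluated."""
    prefix = "" if message is None else message + " "

    def print_message(x):
        # x arrives as a NumPy array; write it to stdout and pass it through.
        sys.stdout.write("%s%s\n" % (prefix, x))
        sys.stdout.flush()
        return x

    # One py_func per tensor; each returns its input unchanged.
    prints = [tf.py_func(print_message, [tensor], tensor.dtype) for tensor in tensors]
    # Make the returned tensor depend on the prints so they run whenever it is evaluated.
    with tf.control_dependencies(prints):
        op = tf.identity(op)
    return op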
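The snippet above targets TensorFlow 1.x graph mode. As an aside that is not part of the original gist: in TensorFlow 2.x the built-in `tf.print` takes an `output_stream` argument that accepts `sys.stdout`, so the `py_func` workaround may no longer be needed. The following is a minimal sketch, assuming TF 2.x; `square_with_logging` is a hypothetical name used only for illustration.

```python
import sys

import tensorflow as tf


@tf.function
def square_with_logging(x, y):
    """Hypothetical example: return x * x, printing x and y to stdout."""
    # output_stream=sys.stdout redirects tf.print away from its default,
    # sys.stderr. Inside tf.function, stateful ops such as tf.print are
    # run automatically, so no explicit tf.control_dependencies is needed.
    tf.print("hello", x, output_stream=sys.stdout)
    tf.print("hello", y, output_stream=sys.stdout)
    return x * x


result = square_with_logging(tf.ones([1, 2]), tf.zeros([1, 3]))
```

Unlike the `py_func` version, the prints here are written by TensorFlow's own print kernel, so the exact number formatting may differ from NumPy's.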