TensorBoard summary example
"""MWE of a tensorboard setup with | |
- a dummy training and validation op | |
- loss for training and validation shown on the same graph | |
- histograms showing weights changing with training | |
To run tensorboard after running this script | |
$ tensorboard --log_dir=/tmp/tensorboard_demo | |
""" | |
import os

import tensorflow as tf

def training():
    """Dummy training operation."""
    loss = tf.truncated_normal([], stddev=0.1, mean=1)
    weights = tf.truncated_normal((2, 2))
    # Add values to a collection so they can be accessed elsewhere in the graph
    tf.add_to_collection("loss", loss)
    tf.add_to_collection("weights", weights)
    return loss


def validation():
    """Dummy validation operation."""
    loss = training()
    return loss

model_loc = '/tmp/tensorboard_demo'

train_op = training()
eval_op = validation()

# Add scalar values to the summary
tf.summary.scalar("loss", tf.get_collection("loss")[0])

# Write out the weights as histograms
# (so you can see how the weights change as training progresses)
tf.summary.histogram("weights", tf.get_collection("weights")[0])

# Make a single write op for all the summaries
write_op = tf.summary.merge_all()

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)

    # Make two writers, one for training and one for validation
    # (this means you get two curves on your loss graph!)
    writer = {}
    writer['train'] = tf.summary.FileWriter(
        os.path.join(model_loc, 'train'), sess.graph)
    writer['validation'] = tf.summary.FileWriter(
        os.path.join(model_loc, 'validation'), sess.graph)

    global_step = 0
    for n_epoch in range(10):

        # Dummy training
        for n_batches in range(25):
            global_step += 1  # NB global_step would be a tf.Variable in a real model
            _, summary = sess.run([train_op, write_op])
            # Add the summary value to the training writer
            writer['train'].add_summary(summary, global_step=global_step)

        # Dummy validation once per epoch
        _, summary = sess.run([eval_op, write_op])
        writer['validation'].add_summary(summary, global_step=global_step)

    # Close the writers so all pending events are flushed to disk
    writer['train'].close()
    writer['validation'].close()
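The script above uses the TF 1.x graph/session API. For comparison, below is a minimal sketch of the same two-writer pattern with the TensorFlow 2.x eager summary API; this is an illustrative assumption (TF >= 2.0, hypothetical log directory /tmp/tensorboard_demo_tf2), not part of the original gist.

"""Sketch: train/validation curves on one graph with TF 2.x summary writers."""
import os

import tensorflow as tf

model_loc = '/tmp/tensorboard_demo_tf2'  # hypothetical log directory

# One writer per curve, mirroring the TF 1.x script above
writers = {name: tf.summary.create_file_writer(os.path.join(model_loc, name))
           for name in ('train', 'validation')}

global_step = 0
for n_epoch in range(10):
    # Dummy training: random loss and weights stand in for a real model
    for n_batches in range(25):
        global_step += 1
        loss = tf.random.truncated_normal([], mean=1.0, stddev=0.1)
        weights = tf.random.truncated_normal((2, 2))
        with writers['train'].as_default():
            tf.summary.scalar('loss', loss, step=global_step)
            tf.summary.histogram('weights', weights, step=global_step)

    # Dummy validation once per epoch
    val_loss = tf.random.truncated_normal([], mean=1.0, stddev=0.1)
    with writers['validation'].as_default():
        tf.summary.scalar('loss', val_loss, step=global_step)

# Flush both writers so all events reach disk
for w in writers.values():
    w.flush()

Running `tensorboard --logdir=/tmp/tensorboard_demo_tf2` should then show the same kind of overlaid train/validation loss curves and weight histograms as the TF 1.x version.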