Created August 2, 2018 10:29
Save annarailton/4d4eafbf86be756a51e41c6c4d9c0212 to your computer and use it in GitHub Desktop.
Plot two different metrics on the same chart in TensorBoard
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# View the graphs (Linux): $ tensorboard --logdir=/tmp/my_tf_model
import os | |
import tempfile | |
import tensorflow as tf | |
import numpy as np | |
from tensorboard import summary as summary_lib | |
from tensorboard.plugins.custom_scalar import layout_pb2 | |
def train_data_gen():
    """Generate one dummy training sample.

    Yields:
        Exactly one ``(features, target)`` pair. ``features`` is a length-3
        vector drawn with ``np.random.normal`` (default loc/scale) and
        ``target`` is a length-3 vector filled with 0.5.
    """
    features = np.random.normal(size=[3])
    target = np.full(3, 0.5)
    yield features, target
def valid_data_gen():
    """Generate one dummy validation sample.

    Yields:
        Exactly one ``(features, target)`` pair. ``features`` is a length-3
        vector drawn with ``np.random.normal`` (default loc/scale) and
        ``target`` is a length-3 vector filled with 0.8.
    """
    features = np.random.normal(size=[3])
    target = np.full(3, 0.8)
    yield features, target
# Sizes for the toy training loop: samples per batch, batches per phase and
# the number of epochs.
batch_size = 25
n_training_batches = 4
n_valid_batches = 2
n_epochs = 5

# Root directory for the event files, under the system temp directory.
summary_loc = os.path.join(tempfile.gettempdir(), 'my_tf_model')
print("Summaries written to {}".format(summary_loc))
# --- Dummy data ------------------------------------------------------------
# Each dataset wraps a one-sample generator, repeated forever and batched.
float_pair = (tf.float32, tf.float32)
train_data = (tf.data.Dataset.from_generator(train_data_gen, float_pair)
              .repeat()
              .batch(batch_size))
valid_data = (tf.data.Dataset.from_generator(valid_data_gen, float_pair)
              .repeat()
              .batch(batch_size))

# Feedable iterator: the string fed into `handle` at run time decides which
# dataset `batch_x` / `batch_y` are drawn from.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_data.output_types, train_data.output_shapes)
batch_x, batch_y = iterator.get_next()
train_iter = train_data.make_initializable_iterator()
valid_iter = valid_data.make_initializable_iterator()

# --- Some ops on the data ----------------------------------------------------
loss = tf.losses.mean_squared_error(batch_x, batch_y)
# Streaming mean of `loss`: `valid_loss_update` accumulates a batch into the
# metric's local variables and `valid_loss` reads the current mean.
valid_loss, valid_loss_update = tf.metrics.mean(loss)

# Both scalars are created under the 'loss' name scope; presumably this yields
# the tags 'loss/training' and 'loss/valid' that the custom-scalar layout
# refers to — confirm in TensorBoard if the scope handling changes.
with tf.name_scope('loss'):
    train_summ = summary_lib.scalar('training', loss)
    valid_summ = summary_lib.scalar('valid', valid_loss)
# Op that re-initializes the local variables (including the streaming-mean
# accumulators behind `valid_loss`). Built once here: calling
# tf.local_variables_initializer() inside the epoch loop would add a fresh op
# to the graph on every epoch.
reset_local_vars = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # String handles select which dataset feeds the shared feedable iterator.
    train_handle, valid_handle = sess.run(
        [train_iter.string_handle(),
         valid_iter.string_handle()])
    sess.run([train_iter.initializer, valid_iter.initializer])

    # Separate 'train' and 'valid' sub-directories so TensorBoard shows them
    # as two runs.
    writer_train = tf.summary.FileWriter(
        os.path.join(summary_loc, 'train'), sess.graph)
    writer_valid = tf.summary.FileWriter(
        os.path.join(summary_loc, 'valid'), sess.graph)
    try:
        # Custom Scalars layout: one 'losses' chart that overlays the
        # 'loss/training' and 'loss/valid' tags.
        layout_summary = summary_lib.custom_scalar_pb(
            layout_pb2.Layout(category=[
                layout_pb2.Category(
                    title='losses',
                    chart=[
                        layout_pb2.Chart(
                            title='losses',
                            multiline=layout_pb2.MultilineChartContent(
                                tag=['loss/training', 'loss/valid']))
                    ])
            ]))
        writer_train.add_summary(layout_summary)

        global_step = 0
        for _epoch in range(n_epochs):
            for _ in range(n_training_batches):  # "Training"
                global_step += 1
                summ = sess.run(train_summ, feed_dict={handle: train_handle})
                writer_train.add_summary(summary=summ, global_step=global_step)

            # Start each validation pass from a fresh running mean.
            sess.run(reset_local_vars)
            for _ in range(n_valid_batches):  # "Validation"
                # Only the metric update is needed; the per-batch training
                # summary the original also evaluated here was discarded.
                sess.run(valid_loss_update, feed_dict={handle: valid_handle})
            # Log the mean validation loss at the last training step.
            summ = sess.run(valid_summ)
            writer_valid.add_summary(summary=summ, global_step=global_step)
    finally:
        # Flush any buffered events to disk and release the event files.
        writer_train.close()
        writer_valid.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.