Created
November 17, 2017 19:55
-
-
Save kkleidal/c88e033193edf92d4027943e49b27d96 to your computer and use it in GitHub Desktop.
Matplotlib Image Summaries in TensorBoard
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import io | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import scipy as scipy # Ensure PIL is also installed: pip install pillow | |
# matplotlib_summary code:
# Builds a TensorFlow image summary from a custom matplotlib plot.
#
# Usage:
#     matplotlib_summary(plotting_function, argument1, argument2, ...,
#                        name="summary name")
#
# plotting_function receives the matplotlib figure as its first argument,
# followed by the numpy values of argument1, ..., argumentN, and draws the
# plot on that figure. matplotlib_summary builds and returns a TensorFlow
# image summary op.
class MatplotlibSummaryOpFactory:
    """Callable factory turning a matplotlib plotting function into an image summary op.

    Each call made with ``name=None`` is given an auto-generated name of the
    form ``matplotlib-summary_<n>``, where ``n`` increases per call.
    """

    def __init__(self):
        # Number of auto-generated summary names handed out so far.
        self.counter = 0

    def _wrap_pltfn(self, plt_fn):
        """Return a numpy-level function suitable for ``tf.py_func``.

        The returned function creates a fresh figure and calls
        ``plt_fn(figure, *args)``. It then renders the figure to PNG and
        returns the decoded pixels as a ``uint8`` array.
        """
        def plot(*args):
            fig = plt.figure()
            try:
                plt_fn(fig, *args)
                with io.BytesIO() as buf:
                    # Save this specific figure rather than pyplot's "current"
                    # one, in case plt_fn opened another figure.
                    fig.savefig(buf, format='png')
                    buf.seek(0)
                    image = plt.imread(buf, format='png')
            finally:
                # Without this, every evaluation of the op leaks a figure.
                plt.close(fig)
            # matplotlib returns PNG pixels as floats in [0, 1]. The py_func
            # output is declared as tf.uint8, so scale back to 0..255.
            if image.dtype != np.uint8:
                image = np.round(image * 255).astype(np.uint8)
            return image
        return plot

    def __call__(self, plt_fn, *args, name=None):
        """Build an image summary op that renders ``plt_fn`` when evaluated.

        Args:
            plt_fn: callable ``plt_fn(figure, *numpy_args)`` that draws on
                the given matplotlib figure.
            *args: tensors whose evaluated numpy values are passed to plt_fn.
            name: summary name. When None, "matplotlib-summary_<n>" is used.

        Returns:
            The ``tf.summary.image`` op for a batch containing one image.
        """
        if name is None:
            self.counter += 1
            name = "matplotlib-summary_%d" % self.counter
        image_tensor = tf.py_func(self._wrap_pltfn(plt_fn), list(args), tf.uint8)
        # py_func output has no static shape. Assumes the saved PNG is RGBA
        # (4 channels), which is what matplotlib's default PNG writer emits.
        image_tensor.set_shape([None, None, 4])
        # tf.summary.image expects a [batch, height, width, channels] tensor.
        return tf.summary.image(name, tf.expand_dims(image_tensor, 0))


matplotlib_summary = MatplotlibSummaryOpFactory()
# END matplotlib_summary code
# Example usage:
def plt_mnist(f, digit):
    """Draw ``digit`` as an image on figure ``f`` and give it a fixed title.

    f: the matplotlib figure handed in by matplotlib_summary.
    digit: numpy value of the tensor passed to matplotlib_summary. Its last
        axis must have size 1; that axis is squeezed away before drawing.
    """
    axes = f.gca()
    pixels = np.squeeze(digit, axis=-1)
    axes.imshow(pixels)
    axes.set_title("A random MNIST digit")
# Build a random 28x28x1 tensor standing in for an MNIST digit and attach a
# matplotlib-rendered image summary to it (TensorFlow 1.x graph-mode API).
digit = tf.random_normal([28, 28, 1])
summary = matplotlib_summary(plt_mnist, digit, name="mnist-summary")
all_summaries = tf.summary.merge_all()
# Event files are written to the current directory; point TensorBoard at it.
summary_writer = tf.summary.FileWriter(".")
try:
    with tf.Session() as sess:
        summ = sess.run(all_summaries)
        summary_writer.add_summary(summ, global_step=0)
finally:
    # FileWriter writes events asynchronously. Close it so the summary is
    # actually flushed to disk before the script exits.
    summary_writer.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Resulting TensorBoard view:
This script lets you use custom matplotlib code to create plots that are displayed in TensorBoard. It is handy for extending TensorBoard to show data that its built-in visualizations do not represent well.