Last active
July 25, 2019 08:44
-
-
Save branislav1991/d497a2c40cb9006376e4d51024db70bb to your computer and use it in GitHub Desktop.
Tensorboard API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io | |
import numpy as np | |
from PIL import Image | |
import tensorflow as tf | |
class Tensorboard:
    """Write scalar, histogram, image and plot summaries to a TensorBoard log dir.

    Every ``log_*`` method builds a ``tf.Summary`` protobuf, adds it to the
    event-file writer at ``global_step`` and flushes immediately, so the event
    is on disk as soon as the call returns.

    Usable as a context manager; the writer is closed on exit.

    NOTE(review): ``tf.summary.FileWriter``, ``tf.Summary`` and
    ``tf.HistogramProto`` are TensorFlow 1.x API names. Under TF 2.x they live
    in ``tf.compat.v1`` and ``FileWriter`` does not work with eager execution
    enabled -- confirm the TF version this is run with.
    """

    def __init__(self, logdir):
        """Open an event-file writer that logs into directory *logdir*."""
        self.writer = tf.summary.FileWriter(logdir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow the caller's exception

    def close(self):
        """Close the underlying event-file writer."""
        self.writer.close()

    def _write(self, summary, global_step):
        """Add *summary* at *global_step* and flush it to disk right away."""
        self.writer.add_summary(summary, global_step=global_step)
        self.writer.flush()

    def _write_png(self, tag, png_bytes, height, width, global_step):
        """Log already PNG-encoded *png_bytes* (height x width) as an image summary."""
        img_summary = tf.Summary.Image(encoded_image_string=png_bytes,
                                       height=height,
                                       width=width)
        summary = tf.Summary()
        summary.value.add(tag=tag, image=img_summary)
        self._write(summary, global_step)

    def log_scalar(self, tag, value, global_step):
        """Log a single number *value* under *tag* at *global_step*."""
        summary = tf.Summary()
        # float() also accepts numpy scalars / 0-d arrays for the proto float field.
        summary.value.add(tag=tag, simple_value=float(value))
        self._write(summary, global_step)

    def log_histogram(self, tag, values, global_step, bins='auto'):
        """Log the distribution of *values* as a TensorBoard histogram.

        Args:
            tag: summary tag.
            values: array-like of numbers (any shape; it is flattened by the
                numpy reductions). Lists are accepted as well as arrays.
            global_step: step to record the histogram at.
            bins: forwarded to ``np.histogram`` (an int or a bin-selection
                strategy name such as ``'auto'``).

        Raises:
            ValueError: if *values* is empty.
        """
        arr = np.asarray(values)
        if arr.size == 0:
            raise ValueError('log_histogram: values must be non-empty')

        counts, bin_edges = np.histogram(arr, bins=bins)
        # Aggregate statistics in float64 so that e.g. int32 input cannot
        # overflow when squared and summed.
        as_float = arr.astype(np.float64)

        hist = tf.HistogramProto()
        hist.min = float(np.min(as_float))
        hist.max = float(np.max(as_float))
        hist.num = int(arr.size)
        hist.sum = float(np.sum(as_float))
        hist.sum_squares = float(np.sum(np.square(as_float)))
        # bucket_limit[i] is the right edge of bucket i (bucket 0 starts at
        # -DBL_MAX in HistogramProto), so the leftmost edge from np.histogram is dropped.
        hist.bucket_limit.extend(float(edge) for edge in bin_edges[1:])
        hist.bucket.extend(float(count) for count in counts)

        summary = tf.Summary()
        summary.value.add(tag=tag, histo=hist)
        self._write(summary, global_step)

    def log_image(self, tag, img, global_step):
        """Log the numpy image *img* under *tag*, PNG-encoded via PIL.

        *img* must be an array ``PIL.Image.fromarray`` accepts; its first two
        dimensions are used as height and width.
        """
        buf = io.BytesIO()
        Image.fromarray(img).save(buf, format='png')
        self._write_png(tag, buf.getvalue(), img.shape[0], img.shape[1], global_step)

    def log_plot(self, tag, figure, global_step):
        """Render *figure* (an object with ``savefig``) to PNG and log it as an image."""
        buf = io.BytesIO()
        figure.savefig(buf, format='png')
        png_bytes = buf.getvalue()
        # Only the dimensions are needed: Image.size is (width, height) and does
        # not require decoding every pixel into an array.
        buf.seek(0)
        with Image.open(buf) as png:
            width, height = png.size
        self._write_png(tag, png_bytes, height, width, global_step)
if __name__ == '__main__':
    # Plotting is only needed for this demo. The non-interactive 'Agg' backend
    # renders into in-memory buffers without a display and must be selected
    # before pyplot is imported.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    tensorboard = Tensorboard('logs')
    try:
        x = np.arange(1, 101)
        y = 20 + 3 * x + np.random.random(100) * 100

        # Log simple values: one scalar event per step.
        for step, value in enumerate(y):
            tensorboard.log_scalar('value', value, step)

        # Log images: a synthetic 64x64 RGB gradient, so the demo does not
        # depend on an image file existing on the local disk.
        rows, cols = np.mgrid[0:64, 0:64]
        img = np.stack([cols * 4, rows * 4, np.full_like(cols, 128)], axis=-1).astype(np.uint8)
        tensorboard.log_image('example_image', img, 0)

        # Log plots: close the figure only after it has been rendered and logged.
        fig = plt.figure()
        plt.plot(x, y, 'o')
        tensorboard.log_plot('example_plot', fig, 0)
        plt.close(fig)

        # Log histograms: a bimodal sample drawn from two normal distributions.
        rng = np.random.RandomState(10)
        a = np.hstack((rng.normal(size=1000), rng.normal(loc=5, scale=2, size=1000)))
        tensorboard.log_histogram('example_hist', a, 0, 'auto')
    finally:
        # Always release the event file, even if one of the demos fails.
        tensorboard.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Great Work!!!
However, it didn't work for me as-is, so I made some changes to fix it.
My environment: Ubuntu 16.04, TensorBoard 1.8.0.
I import the matplotlib package like this:
```python
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
```
and
I swapped lines 86 and 87 of the code so that the figure is closed only after it has been logged:
```python
tensorboard.log_plot('example_plot', fig, 0)
plt.close()
```
These changes fixed the errors I was getting.