Skip to content

Instantly share code, notes, and snippets.

@alantian
Created August 1, 2018 06:06
Show Gist options
  • Save alantian/60603e99b89ac1a54a7fa58f3d1a1e9e to your computer and use it in GitHub Desktop.
Save alantian/60603e99b89ac1a54a7fa58f3d1a1e9e to your computer and use it in GitHub Desktop.
Chainer extension that reports scalars to TensorBoard
import os
import sys
from chainer.training import extension
from chainer.training.extensions import log_report as log_report_module
from chainer.training.extensions import util
import tensorflow as tf
# Logger is from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
class TensorboardLogger(object):
    """Logging in tensorboard without tensorflow ops.

    Writes ``tf.Summary`` protocol buffers directly through a TF1-style
    ``tf.summary.FileWriter``, so no graph or session is required.
    """

    def __init__(self, log_dir):
        """Creates a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def close(self):
        """Close the underlying writer.

        Safe to call more than once: after the first call the writer is
        released and later calls do nothing.
        """
        if self.writer is not None:
            self.writer.close()
            self.writer = None

    def log_scalar(self, tag, value, step):
        """Log a scalar variable.

        Parameters
        ----------
        tag : str
            Name of the scalar.
        value : float
            Value to record; converted with ``float()`` before writing so
            ints and numpy scalars are accepted.
        step : int
            Training step used as the x-axis value in TensorBoard.

        Raises
        ------
        RuntimeError
            If the logger has already been closed.
        """
        if self.writer is None:
            raise RuntimeError('TensorboardLogger is closed')
        summary = tf.Summary(
            value=[tf.Summary.Value(tag=tag, simple_value=float(value))])
        self.writer.add_summary(summary, step)
        # Flush on every call so each value becomes visible in TensorBoard
        # promptly instead of whenever the writer next flushes on its own.
        self.writer.flush()
class TensorboardReport(extension.Extension):
    """Trainer extension that mirrors LogReport observations into TensorBoard.

    On each call, every observation appended to the LogReport's ``log``
    since the previous call is written as scalar summaries. The
    observation's index in ``log`` is used as the TensorBoard step.

    Args:
        entries: Iterable of observation keys to report. If ``None``, every
            key present in each observation is reported. Keys missing from
            an observation are skipped.
        log_report: Either the name of a ``LogReport`` extension registered
            on the trainer, or a ``LogReport`` instance owned by this
            extension (which is then updated on every call and serialized
            together with this extension).
        tensorboard_out: Sub-directory of ``trainer.out`` that receives the
            TensorBoard event files.
    """

    def __init__(self,
                 entries=None,
                 log_report='LogReport',
                 tensorboard_out='tb'):
        self._entries = entries
        self._log_report = log_report
        self._tensorboard_out = tensorboard_out
        self._tensorboard_logger = None
        self._log_len = 0  # number of observations already logged

    def __call__(self, trainer):
        if self._tensorboard_logger is None:
            # Created lazily because trainer.out is only available here.
            self._tensorboard_logger = TensorboardLogger(
                log_dir=os.path.join(trainer.out, self._tensorboard_out))

        log_report = self._log_report
        if isinstance(log_report, str):
            log_report = trainer.get_extension(log_report)
        elif isinstance(log_report, log_report_module.LogReport):
            log_report(trainer)  # update the log report
        else:
            raise TypeError(
                'log report has a wrong type %s' % type(log_report))

        log = log_report.log
        # Only write observations appended since the previous call.
        for step in range(self._log_len, len(log)):
            self._log(log[step], step)
        self._log_len = max(self._log_len, len(log))

    def finalize(self, trainer=None):
        """Close the TensorBoard event writer at the end of training.

        ``trainer`` is optional and ignored, so this works whether the
        trainer invokes ``finalize()`` with or without an argument.
        """
        if self._tensorboard_logger is not None:
            self._tensorboard_logger.close()
            self._tensorboard_logger = None

    def serialize(self, serializer):
        log_report = self._log_report
        if isinstance(log_report, log_report_module.LogReport):
            log_report.serialize(serializer['_log_report'])
        # Persist how many observations were already written. Without this,
        # resuming from a snapshot restores the LogReport's log but resets
        # the counter to 0, so every earlier point is written again.
        try:
            self._log_len = serializer('_log_len', self._log_len)
        except KeyError:
            # Snapshot was saved before this field existed; keep the
            # current counter instead of failing the whole resume.
            pass

    def _log(self, observation, step):
        """Write the selected entries of one observation at ``step``."""
        if self._entries is not None:
            entries = self._entries
        else:
            entries = observation.keys()
        logger = self._tensorboard_logger
        for entry in entries:
            if entry in observation:
                logger.log_scalar(entry, observation[entry], step)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment