Created August 1, 2018 06:06
Save alantian/60603e99b89ac1a54a7fa58f3d1a1e9e to your computer and use it in GitHub Desktop.
Chainer Extension that reports scalars to tensorboard
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
from chainer.training import extension | |
from chainer.training.extensions import log_report as log_report_module | |
from chainer.training.extensions import util | |
import tensorflow as tf | |
# Logger is adapted from
# https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
class TensorboardLogger(object):
    """Log scalars to TensorBoard event files without building TensorFlow ops.

    Summaries are built directly as ``tf.Summary`` protocol buffers and written
    with ``tf.summary.FileWriter``.  Both names belong to the TensorFlow 1.x
    API; they do not exist under these names in TensorFlow 2.
    """

    def __init__(self, log_dir):
        """Create a summary writer that logs to ``log_dir``."""
        self.writer = tf.summary.FileWriter(log_dir)

    def flush(self):
        """Flush pending summaries to disk (no-op once closed)."""
        if self.writer is not None:
            self.writer.flush()

    def close(self):
        """Close the underlying writer.

        Safe to call more than once: after the first call the writer is
        dropped and further calls do nothing.
        """
        if self.writer is not None:
            self.writer.close()
            self.writer = None

    def log_scalar(self, tag, value, step, flush=True):
        """Log a scalar value.

        Parameters
        ----------
        tag : str
            Name of the scalar (the TensorBoard tag).
        value : float
            Scalar value to record.
        step : int
            Step (x-axis position) associated with the value.
        flush : bool, optional
            When True (the default) the writer is flushed right after the
            summary is added.  Pass False when logging many scalars at once
            and call :meth:`flush` afterwards.
        """
        summary = tf.Summary(
            value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)
        if flush:
            self.writer.flush()
class TensorboardReport(extension.Extension):
    """Trainer extension that mirrors LogReport entries to TensorBoard.

    On every invocation, each entry appended to the LogReport's ``log`` since
    the previous invocation is written as TensorBoard scalars to the
    directory ``os.path.join(trainer.out, tensorboard_out)``.

    Args:
        entries: Keys to report from each log entry.  Keys missing from an
            entry are skipped.  ``None`` reports every key of each entry.
        log_report: Either the name of a LogReport extension registered on
            the trainer (looked up with ``trainer.get_extension``), or a
            LogReport instance that this extension calls on every invocation
            and serializes together with its own state.
        tensorboard_out: Sub-directory of ``trainer.out`` for event files.
        step_key: Optional key of a log entry whose value is used as the
            TensorBoard step (e.g. ``'iteration'``).  When ``None`` (the
            default) or when the key is absent from an entry, the step is the
            entry's index in the log (0, 1, 2, ...).
    """

    def __init__(self,
                 entries=None,
                 log_report='LogReport',
                 tensorboard_out='tb',
                 step_key=None):
        self._entries = entries
        self._log_report = log_report
        self._tensorboard_out = tensorboard_out
        self._step_key = step_key
        # Created lazily on the first call because it needs ``trainer.out``.
        self._tensorboard_logger = None
        self._log_len = 0  # number of log entries already reported

    def __call__(self, trainer):
        if self._tensorboard_logger is None:
            self._tensorboard_logger = TensorboardLogger(
                log_dir=os.path.join(trainer.out, self._tensorboard_out))

        log_report = self._log_report
        if isinstance(log_report, str):
            log_report = trainer.get_extension(log_report)
        elif isinstance(log_report, log_report_module.LogReport):
            log_report(trainer)  # update the log report
        else:
            raise TypeError(
                'log report has a wrong type %s' % type(log_report))

        log = log_report.log
        for index in range(self._log_len, len(log)):
            self._log(log[index], index)
            # Advance per entry so that a failure part-way through does not
            # cause already-written entries to be emitted a second time.
            self._log_len = index + 1

    def finalize(self):
        """Close the TensorBoard writer at the end of training."""
        if self._tensorboard_logger is not None:
            self._tensorboard_logger.close()
            self._tensorboard_logger = None

    def serialize(self, serializer):
        log_report = self._log_report
        if isinstance(log_report, log_report_module.LogReport):
            log_report.serialize(serializer['_log_report'])
        # Persist how many entries were already reported.  Otherwise, after
        # resuming from a snapshot (which restores the LogReport's log), every
        # old entry would be written to TensorBoard again.
        try:
            self._log_len = serializer('_log_len', self._log_len)
        except KeyError:
            # Snapshots written before this field existed do not contain it.
            import warnings
            warnings.warn(
                'TensorboardReport: "_log_len" not found in the snapshot; '
                'previously reported entries may be logged again.')

    def _log(self, observation, step):
        """Write the selected scalars of one log entry at ``step``."""
        if self._step_key is not None and self._step_key in observation:
            step = observation[self._step_key]
        if self._entries is not None:
            entries = self._entries
        else:
            entries = observation.keys()
        for entry in entries:
            if entry in observation:
                self._tensorboard_logger.log_scalar(
                    entry, observation[entry], step)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.