@JossWhittle · Last active April 12, 2020
import numpy as np
import tensorflow as tf


class LogMetrics(tf.keras.callbacks.Callback):
    """Evaluates a loss and a set of metrics on a dataset at the end of each
    epoch, logging the results to TensorBoard and an in-memory history."""

    def __init__(self, log_dir, loss, metrics, steps, dataset, training=False):
        super(LogMetrics, self).__init__()
        self.log_dir = log_dir
        self.metrics = metrics
        self.steps = steps
        # Wrap the dataset in an iterator so batches can be drawn with next();
        # the dataset should yield at least `steps` batches per epoch (e.g. via .repeat()).
        self.dataset = iter(dataset)
        self.training = training
        self.writer = tf.summary.create_file_writer(log_dir)
        self.loss = loss
        self.loss_metric = tf.keras.metrics.Mean(name='loss')

        # Per-epoch history of the loss and each metric.
        self.history = {'loss': []}
        for metric in self.metrics:
            self.history[metric.name] = []

    def on_epoch_end(self, epoch, logs=None):
        # Clear any state accumulated during a previous epoch.
        self.loss_metric.reset_states()
        for metric in self.metrics:
            metric.reset_states()

        # Accumulate the loss and metrics over a fixed number of batches.
        for step in range(self.steps):
            x, y_true = next(self.dataset)
            y_pred = self.model(x, training=self.training)
            self.loss_metric.update_state(self.loss(y_true, y_pred))
            for metric in self.metrics:
                metric.update_state(y_true, y_pred)

        # Write the aggregated results to TensorBoard and the history dict,
        # then reset the metric state for the next epoch.
        with self.writer.as_default():
            tf.summary.scalar(self.loss_metric.name, self.loss_metric.result(), step=epoch)
            self.history['loss'].append(float(self.loss_metric.result()))
            self.loss_metric.reset_states()

            for metric in self.metrics:
                tf.summary.scalar(metric.name, metric.result(), step=epoch)
                self.history[metric.name].append(float(metric.result()))
                metric.reset_states()

            self.writer.flush()
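A minimal usage sketch (the toy data, model, metric choice, and log directory below are illustrative assumptions, not part of the gist): the callback is constructed with its own loss and metric instances plus a repeating evaluation dataset, and then passed to model.fit alongside the usual training arguments.

import numpy as np
import tensorflow as tf

# Toy data, purely for illustration.
x = np.random.rand(256, 8).astype(np.float32)
y = np.random.randint(0, 2, size=(256, 1)).astype(np.float32)

train_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
# The callback pulls batches with next(), so the evaluation dataset must repeat.
val_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(32).repeat()

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

loss = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer='adam', loss=loss)

val_callback = LogMetrics(
    log_dir='logs/validation',
    loss=loss,
    metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy')],
    steps=8,           # number of batches drawn from the dataset per epoch
    dataset=val_ds,
    training=False)    # run the model in inference mode

model.fit(train_ds, epochs=5, callbacks=[val_callback])

The logged scalars can then be inspected by pointing TensorBoard at the chosen log directory, or read back from val_callback.history after training.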