Skip to content

Instantly share code, notes, and snippets.

@dgrahn
Created June 26, 2020 12:16
Show Gist options
  • Save dgrahn/45acc8ee54c515d6ac28e2696db2a12d to your computer and use it in GitHub Desktop.
Callback to report the maximum of a metric for the TensorBoard HParams plugin
import numpy as np
import tensorflow as tf
from collections import defaultdict
class MaxMetrics(tf.keras.callbacks.Callback):
    """Adds the running max of selected metrics to the logs at the end of each epoch.

    For every metric name ``m`` given to the constructor, ``on_epoch_end``
    writes ``logs['max_' + m]`` (the max of ``logs[m]`` seen so far) and, when
    the validation value is present, ``logs['val_max_' + m]`` (the max of
    ``logs['val_' + m]`` seen so far). The keys are written into the ``logs``
    mapping passed to ``on_epoch_end``, so anything that later reads that same
    mapping can see them.
    """

    def __init__(self, metrics):
        """Creates a new MaxMetrics callback.

        Args:
            metrics (list): The metric names to report, without the ``val_``
                prefix (e.g. ``['accuracy']``).
        """
        # Let the Keras base class initialise its own state (e.g. its
        # ``model``/``params`` attributes) before we add ours.
        super().__init__()
        self.metrics = metrics
        # Seed with -inf rather than 0.0 so a metric whose values are never
        # positive still reports its true maximum instead of 0.0.
        self.max = defaultdict(lambda: -np.inf)

    def update_max(self, logs, name):
        """Finds and updates the max value of the specified metric.

        Args:
            logs (dict): The logs from the epoch.
            name (str): The key of the metric in ``logs``.

        Returns:
            The max value of ``logs[name]`` seen so far, as produced by
            ``np.maximum``.

        Raises:
            KeyError: If ``name`` is not present in ``logs``.
        """
        best = np.maximum(logs[name], self.max[name])
        self.max[name] = best
        return best

    def on_epoch_end(self, epoch, logs=None):
        """Writes ``max_<m>`` and ``val_max_<m>`` entries into ``logs``.

        Args:
            epoch (int): Index of the epoch that just finished (unused).
            logs (dict | None): Metric results for the epoch; updated in place.

        Raises:
            KeyError: If a configured (non-validation) metric is missing from
                ``logs``, which usually indicates a misspelled metric name.
        """
        if logs is None:
            # Nothing to read from or to write into.
            return
        for metric in self.metrics:
            logs['max_' + metric] = self.update_max(logs, metric)
            val_name = 'val_' + metric
            # Validation metrics are absent when fit() has no validation data
            # (or on epochs skipped by validation_freq); skip instead of
            # raising KeyError.
            if val_name in logs:
                logs['val_max_' + metric] = self.update_max(logs, val_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment