Created
June 26, 2020 12:16
-
-
Save dgrahn/45acc8ee54c515d6ac28e2696db2a12d to your computer and use it in GitHub Desktop.
Callback that reports the maximum of a metric for the TensorBoard HParams plugin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from collections import defaultdict | |
class MaxMetrics(tf.keras.callbacks.Callback):
    """Adds the running maximum of selected metrics to the logs each epoch.

    For every name ``m`` in ``metrics`` the callback writes two entries into
    the epoch logs:

    * ``max_<m>``: the largest value of ``logs[m]`` seen so far.
    * ``val_max_<m>``: the largest value of ``logs['val_<m>']`` seen so far.
      This entry is only written on epochs that report the validation metric.

    Callbacks that run after this one and read the same logs dict (e.g. a
    TensorBoard/HParams callback) therefore see these extra entries.
    """

    def __init__(self, metrics):
        """Creates a new MaxMetrics callback.

        Args:
            metrics (list): The metric names to report, without the ``val_``
                prefix.
        """
        # Initialise the base Callback so the attributes Keras sets and reads
        # on every callback exist.
        super().__init__()
        self.metrics = metrics
        # Running maxima keyed by log name. Each entry is seeded from the first
        # observed value rather than a fixed 0.0, so metrics that are always
        # negative are reported correctly.
        self.max = {}

    def update_max(self, logs, name):
        """Finds and updates the max value of the specified metric.

        NaN values are ignored via ``np.fmax``. A single epoch that produced
        NaN therefore does not permanently turn the reported maximum into NaN.

        Args:
            logs (dict): The logs from the epoch.
            name (str): The metric name (a key of ``logs``).

        Returns:
            The max value of the metric seen so far.

        Raises:
            KeyError: If ``name`` is not present in ``logs``.
        """
        value = logs[name]
        if name in self.max:
            value = np.fmax(value, self.max[name])
        self.max[name] = value
        return value

    def on_epoch_end(self, epoch, logs=None):
        """Writes ``max_<m>`` and ``val_max_<m>`` entries into ``logs``.

        Args:
            epoch (int): Index of the epoch that just finished (unused).
            logs (dict): Metric results for this epoch; updated in place.
        """
        if logs is None:
            return
        for metric in self.metrics:
            # Training metrics are expected every epoch. A missing one is
            # almost certainly a misspelled name, so let the KeyError surface.
            logs['max_' + metric] = self.update_max(logs, metric)
            # Validation metrics are absent when no validation data is given
            # or on epochs skipped by validation_freq, so only update the
            # maximum when the value was actually reported.
            val_name = 'val_' + metric
            if val_name in logs:
                logs['val_max_' + metric] = self.update_max(logs, val_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment