Last active
May 12, 2021 06:48
-
-
Save angeligareta/df7e2758485a35b2267acd570d64c69e to your computer and use it in GitHub Desktop.
A better comparison of TensorBoard experiments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import tensorboard as tb | |
import statistics | |
def get_best_epoch_per_model(
    experiment_id: str,
    metrics,
    format_as_percentage: bool = False,
    sort_values: bool = False,
    group_by_model_names: bool = False,
    show_only_validation: bool = True,
):
    """Return, for every run of a TensorBoard experiment, its best step.

    The scalars of the experiment are downloaded, reshaped so that each
    (run, step) pair is one row with one column per tag, and scored with the
    harmonic mean of the requested metrics. For each run the row with the
    highest harmonic mean is kept.

    Args:
        experiment_id: Identifier passed to
            ``tb.data.experimental.ExperimentFromDev``.
        metrics: Tag name, or iterable of tag names, to keep and to combine
            into the ``harmonic_mean`` column.
        format_as_percentage: If true, return a pandas ``Styler`` that renders
            the metric columns and ``harmonic_mean`` as percentages instead of
            a plain ``DataFrame``.
        sort_values: If true, order the rows by descending ``harmonic_mean``;
            otherwise rows follow the sorted order of the run names.
        group_by_model_names: If true, truncate each run name at its first
            ``"/"`` before grouping, so runs sharing that prefix compete for a
            single best row.
        show_only_validation: If true, keep only runs whose name matches
            ``"val"``.

    Returns:
        A ``DataFrame`` with the columns ``run``, ``step``, one column per
        metric and ``harmonic_mean`` (one row per run), or a ``Styler``
        wrapping it when ``format_as_percentage`` is true.

    Raises:
        KeyError: If a requested metric is not a tag of the experiment.
        statistics.StatisticsError: If a metric value is negative, because
            the harmonic mean is undefined for negative data.
    """
    # Accept a single tag name or any iterable of names; a concrete list is
    # needed for DataFrame column selection and for building the format map.
    metrics = [metrics] if isinstance(metrics, str) else list(metrics)

    # Download tensorboard experiment
    experiment = tb.data.experimental.ExperimentFromDev(experiment_id)
    df = experiment.get_scalars()

    # Pivot table (manually, to allow for nan values)
    df = df.pivot_table(
        values="value",
        index=["run", "step"],
        columns="tag",
        dropna=False,
    )

    # Keep only the relevant metrics and discard rows missing any of them
    # (e.g. batch-level results that do not log every metric).
    df = df[metrics].dropna().reset_index()

    # Harmonic mean across the selected metrics gives one score per row.
    df["harmonic_mean"] = df[metrics].apply(statistics.harmonic_mean, axis=1)

    # Optionally, filter validation runs
    if show_only_validation:
        # Copy so that the later assignment to "run" modifies an independent
        # frame rather than a slice of the unfiltered one.
        df = df[df["run"].str.contains("val")].copy()

    # Optionally, update run column to unify tests of same model name
    if group_by_model_names:
        df["run"] = df["run"].str.split("/", n=1).str[0]

    # Pick the row with the maximum harmonic mean in each run. idxmax returns
    # the first index label of the maximum, so ties resolve to the earliest
    # row, and the "run" column is preserved regardless of how the installed
    # pandas version treats grouping columns inside groupby.apply.
    best_rows = df.groupby("run")["harmonic_mean"].idxmax()
    df = df.loc[best_rows].reset_index(drop=True)

    # Optionally, sort values based on overall performance.
    if sort_values:
        df = df.sort_values(by="harmonic_mean", ascending=False).reset_index(drop=True)

    # Optionally, format metrics as percentage
    if format_as_percentage:
        percent_columns = [*metrics, "harmonic_mean"]
        df = df.style.format(
            {column: "{:,.2%}".format for column in percent_columns}
        )

    return df
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment