@lmassaron
Last active February 18, 2019 13:24
plot_keras_history
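import matplotlib.pyplot as plt


def plot_keras_history(history, measures):
    """
    Plot training and validation curves, two panels per row.

    history: Keras History object returned by model.fit()
    measures: list of metric names recorded in history.history (e.g. 'loss');
              each needs a matching 'val_' entry, so fit() must have been
              run with validation data
    """
    # Rows needed for two panels per row (the ceiling of len(measures) / 2)
    rows = len(measures) // 2 + len(measures) % 2
    # squeeze=False always returns a 2-D array of Axes, even for a single row,
    # so it can be flattened without a try/except; height scales with rows
    fig, panels = plt.subplots(rows, 2, figsize=(15, 5 * rows), squeeze=False)
    plt.subplots_adjust(top=0.99, bottom=0.01, hspace=0.4, wspace=0.2)
    panels = panels.flatten()
    for k, measure in enumerate(measures):
        panel = panels[k]
        panel.set_title(measure + ' history')
        panel.plot(history.epoch, history.history[measure],
                   label="Train " + measure)
        panel.plot(history.epoch, history.history["val_" + measure],
                   label="Validation " + measure)
        panel.set(xlabel='epochs', ylabel=measure)
        panel.legend()
    # Hide the leftover panel when the number of measures is odd
    for panel in panels[len(measures):]:
        panel.set_visible(False)
    # plt.show() takes no figure argument; it displays all open figures
    plt.show()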
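To try the function without training a model, a stand-in for the Keras `History` object should be enough, because the function only reads its `epoch` list and its `history` dict. The sketch below assumes exactly that; `make_fake_history` and `grid_rows` are hypothetical helpers written for illustration (not part of Keras), and the curves they produce are synthetic. `grid_rows` also spells out the row arithmetic: `len(measures) // 2 + len(measures) % 2` is the ceiling of `len(measures) / 2`.

```python
import math
from types import SimpleNamespace


def grid_rows(n_measures, n_cols=2):
    """Rows needed to lay out n_measures panels, n_cols per row.

    For n_cols=2 this matches the expression used in plot_keras_history:
    len(measures) // 2 + len(measures) % 2.
    """
    return math.ceil(n_measures / n_cols)


def make_fake_history(measures, n_epochs=5):
    """Hypothetical stand-in for keras.callbacks.History.

    Provides only the two attributes plot_keras_history reads:
    `epoch` (list of epoch indices) and `history` (dict of per-epoch
    values), with a "val_" entry for every measure. Values are synthetic.
    """
    history = {}
    for measure in measures:
        # Synthetic curves: training decays as 1/(e+1), validation sits 0.1 above it
        train = [1.0 / (e + 1) for e in range(n_epochs)]
        history[measure] = train
        history["val_" + measure] = [v + 0.1 for v in train]
    return SimpleNamespace(epoch=list(range(n_epochs)), history=history)


fake = make_fake_history(["loss", "mae", "mse"])
print(grid_rows(3))          # 2
print(sorted(fake.history))  # ['loss', 'mae', 'mse', 'val_loss', 'val_mae', 'val_mse']
```

Passing `fake` to `plot_keras_history` with the same three measure names should then draw three panels on a two-row grid, leaving one cell unused.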