-
-
Save whyboris/91ee793ddc92cf1e824978cf31bb790c to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt | |
# Plot model history more easily
# when plotting, smooth out the points by some factor (0.5 = rough, 0.99 = smooth)
# method taken from `Deep Learning with Python` by François Chollet
def smooth_curve(points, factor=0.75):
    """Return an exponentially smoothed copy of ``points`` as a list.

    The first value is kept unchanged. Every later value is blended with the
    previously smoothed value as ``previous * factor + point * (1 - factor)``,
    so a larger ``factor`` gives a smoother curve.

    Args:
        points: iterable of numeric values, in plotting order.
        factor: weight given to the previously smoothed value.

    Returns:
        A new list with one smoothed value per input value (empty when
        ``points`` is empty).
    """
    remaining = iter(points)
    try:
        # the first point has nothing before it to blend with
        result = [next(remaining)]
    except StopIteration:
        return []
    for value in remaining:
        result.append(result[-1] * factor + value * (1 - factor))
    return result
def set_plot_history_data(ax, history, which_graph, trim=5):
    """Draw smoothed training and validation curves for one metric on ``ax``.

    Args:
        ax: matplotlib axes to draw on.
        history: object whose ``history`` attribute maps metric names to
            per-epoch value lists (e.g. the return value of Keras
            ``model.fit``).
        which_graph: ``'acc'`` to plot accuracy or ``'loss'`` to plot loss.
        trim: number of leading epochs to leave out of the plot, because the
            first few epochs may skew the (loss) graph. Ignored when the
            history is not longer than ``trim`` epochs, so short runs still
            produce a visible curve.

    Raises:
        ValueError: if ``which_graph`` is neither ``'acc'`` nor ``'loss'``.
        KeyError: if ``history.history`` has no series for the metric.
    """
    metrics = history.history
    if which_graph == 'acc':
        # the accuracy series is stored as 'acc' by some Keras versions and as
        # 'accuracy' by others; accept either spelling
        metric = 'acc' if 'acc' in metrics else 'accuracy'
    elif which_graph == 'loss':
        metric = 'loss'
    else:
        raise ValueError(
            "which_graph must be 'acc' or 'loss', got " + repr(which_graph))

    train = smooth_curve(metrics[metric])
    valid = smooth_curve(metrics['val_' + metric])

    # make plots look like xkcd
    # NOTE: called outside a `with` block, so the style change is global
    plt.xkcd()

    epochs = range(1, len(train) + 1)
    if trim >= len(train):
        # trimming would remove every point; show the whole run instead
        trim = 0

    # each curve is drawn twice: a thin labelled line plus a wide,
    # nearly transparent copy that acts as a soft glow behind it
    ax.plot(epochs[trim:], train[trim:], 'dodgerblue', label='Training')
    ax.plot(epochs[trim:], train[trim:], 'dodgerblue', linewidth=15, alpha=0.1)
    ax.plot(epochs[trim:], valid[trim:], 'g', label='Validation')
    ax.plot(epochs[trim:], valid[trim:], 'g', linewidth=15, alpha=0.1)
def get_max_validation_accuracy(history):
    """Return a label with the peak smoothed validation accuracy in percent.

    Args:
        history: object whose ``history`` attribute maps metric names to
            per-epoch value lists (e.g. the return value of Keras
            ``model.fit``).

    Returns:
        A string such as ``'Max validation accuracy ≈ 84.3%'``.

    Raises:
        KeyError: if no validation accuracy series is present.
        ValueError: if the validation accuracy series is empty.
    """
    metrics = history.history
    # the series is stored as 'val_acc' by some Keras versions and as
    # 'val_accuracy' by others; accept either spelling
    key = 'val_acc' if 'val_acc' in metrics else 'val_accuracy'
    validation = smooth_curve(metrics[key])
    ymax = max(validation)
    # scale to percent BEFORE rounding: rounding first and multiplying after
    # can print float noise (e.g. 0.57 * 100 == 56.99999999999999)
    return 'Max validation accuracy ≈ ' + str(round(ymax * 100, 1)) + '%'
def plot_history(history):
    """Plot accuracy (top) and loss (bottom) curves for a training run.

    Builds a two-row figure sharing the x axis, fills each row via
    ``set_plot_history_data`` and annotates the accuracy panel with the peak
    validation accuracy. The figure is left as the current pyplot figure;
    nothing is returned.

    Args:
        history: object whose ``history`` attribute maps metric names to
            per-epoch value lists (e.g. the return value of Keras
            ``model.fit``).
    """
    fig, (ax1, ax2) = plt.subplots(nrows=2,
                                   ncols=1,
                                   figsize=(10, 6),
                                   sharex=True,
                                   gridspec_kw={'height_ratios': [5, 2]})

    set_plot_history_data(ax1, history, 'acc')
    set_plot_history_data(ax2, history, 'loss')

    # Accuracy graph
    ax1.set_ylabel('Accuracy')
    ax1.set_ylim(bottom=0.5, top=1)
    ax1.legend(loc="lower right")
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.xaxis.set_ticks_position('none')
    ax1.spines['bottom'].set_visible(False)

    # max accuracy text
    # added through ax1 (not plt.text) so the label belongs to the accuracy
    # axes it is positioned against, rather than to whichever axes is current
    ax1.text(0.97,
             0.97,
             get_max_validation_accuracy(history),
             horizontalalignment='right',
             verticalalignment='top',
             transform=ax1.transAxes,
             fontsize=12)

    # Loss graph
    ax2.set_ylabel('Loss')
    ax2.set_yticks([])
    ax2.set_xlabel('Epochs')
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)

    plt.tight_layout()
# How to use (assuming you are using Keras), once `model`, `x` and `y` exist:
#
#     history = model.fit(x, y, ...)  # `...` stands for your other fit() arguments
#     plot_history(history)
This is now a pip package: https://pypi.org/project/keras-hist-graph/
Install with pip install keras-hist-graph
Use thus:
from keras_hist_graph import plot_history
history = model.fit(x, y, ...) # standard Keras training code
plot_history(history)
This is now a pip package: https://pypi.org/project/keras-hist-graph/
Install with `pip install keras-hist-graph`
Use thus:
from keras_hist_graph import plot_history
history = model.fit(x, y, ...) # standard Keras training code
plot_history(history)
Getting this error:
KeyError Traceback (most recent call last)
in
1 from keras_hist_graph import plot_history
2
----> 3 plot_history(history)
~\anaconda3\envs\tf_gpu\lib\site-packages\keras_hist_graph\keras_hist_graph.py in plot_history(history, start_epoch, smooth_factor, xkcd, fig_size, min_accuracy)
68 )
69
---> 70 set_plot_history_data(ax1, history, "acc", start_epoch, smooth_factor, xkcd)
71
72 set_plot_history_data(ax2, history, "loss", start_epoch, smooth_factor, xkcd)
~\anaconda3\envs\tf_gpu\lib\site-packages\keras_hist_graph\keras_hist_graph.py in set_plot_history_data(ax, history, which_graph, start_epoch, smooth_factor, xkcd)
21
22 if which_graph == "acc":
---> 23 train = smooth_curve(history.history["acc"], smooth_factor)
24 valid = smooth_curve(history.history["val_acc"], smooth_factor)
25
KeyError: 'acc'
Thank you for the bug report. At some point in the last year something changed -- I believe Keras no longer records accuracy under the `acc` key in `history.history`.
❓ 🤔
I have this code as a PyPI package: https://pypi.org/project/keras-hist-graph/
And there's a repository for it too: https://github.com/whyboris/keras-hist-graph
I've not been doing much Keras lately, so I might not fix it for a while 🤷 -- if you happen to figure something out - please open a PR 🙇
Results in a graph like this: