Keras Loss & Accuracy Plot Helper Function
import matplotlib.pyplot as plt

# Plot model history more easily
# when plotting, smooth out the points by some factor (0.5 = rough, 0.99 = smooth)
# method taken from `Deep Learning with Python` by François Chollet
def smooth_curve(points, factor=0.75):
    smoothed_points = []
    for point in points:
        if smoothed_points:
            previous = smoothed_points[-1]
            smoothed_points.append(previous * factor + point * (1 - factor))
        else:
            smoothed_points.append(point)
    return smoothed_points
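
# e.g. smooth_curve([1.0, 2.0, 3.0]) with the default factor of 0.75
# returns [1.0, 1.25, 1.6875]  (each point = 0.75 * previous_smoothed + 0.25 * current)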
def set_plot_history_data(ax, history, which_graph):
    # note: newer Keras versions name these keys 'accuracy' / 'val_accuracy'
    if which_graph == 'acc':
        train = smooth_curve(history.history['acc'])
        valid = smooth_curve(history.history['val_acc'])

    if which_graph == 'loss':
        train = smooth_curve(history.history['loss'])
        valid = smooth_curve(history.history['val_loss'])

    plt.xkcd()  # make plots look like xkcd

    epochs = range(1, len(train) + 1)

    trim = 5  # remove the first 5 epochs, which can have large values that skew the loss graph

    # each curve is drawn twice: a thin labelled line plus a thick
    # transparent one underneath for a soft "glow" effect
    ax.plot(epochs[trim:], train[trim:], 'dodgerblue', label='Training')
    ax.plot(epochs[trim:], train[trim:], 'dodgerblue', linewidth=15, alpha=0.1)

    ax.plot(epochs[trim:], valid[trim:], 'g', label='Validation')
    ax.plot(epochs[trim:], valid[trim:], 'g', linewidth=15, alpha=0.1)
def get_max_validation_accuracy(history):
    validation = smooth_curve(history.history['val_acc'])
    ymax = max(validation)
    # round after converting to a percentage to avoid floating-point artifacts like 92.30000000000001
    return 'Max validation accuracy ≈ ' + str(round(ymax * 100, 1)) + '%'
def plot_history(history): | |
fig, (ax1, ax2) = plt.subplots(nrows=2, | |
ncols=1, | |
figsize=(10, 6), | |
sharex=True, | |
gridspec_kw = {'height_ratios':[5, 2]}) | |
set_plot_history_data(ax1, history, 'acc') | |
set_plot_history_data(ax2, history, 'loss') | |
# Accuracy graph | |
ax1.set_ylabel('Accuracy') | |
ax1.set_ylim(bottom=0.5, top=1) | |
ax1.legend(loc="lower right") | |
ax1.spines['top'].set_visible(False) | |
ax1.spines['right'].set_visible(False) | |
ax1.xaxis.set_ticks_position('none') | |
ax1.spines['bottom'].set_visible(False) | |
# max accuracty text | |
plt.text(0.97, | |
0.97, | |
get_max_validation_accuracy(history), | |
horizontalalignment='right', | |
verticalalignment='top', | |
transform=ax1.transAxes, | |
fontsize=12) | |
# Loss graph | |
ax2.set_ylabel('Loss') | |
ax2.set_yticks([]) | |
ax2.plot(legend=False) | |
ax2.set_xlabel('Epochs') | |
ax2.spines['top'].set_visible(False) | |
ax2.spines['right'].set_visible(False) | |
plt.tight_layout() | |
# how to use:
# assuming you are using Keras
history = model.fit(x, y, ...)
plot_history(history)
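
For a self-contained run, here is a minimal sketch with a made-up model and random data (the architecture, shapes, and training settings are placeholders, not part of the gist). The only real requirements are that `fit` receives validation data and that an accuracy metric is tracked, so `history.history` contains the `acc`/`val_acc` keys the helpers read; depending on your Keras version those keys may instead be named `accuracy`/`val_accuracy` (see the comments below).

import numpy as np
from tensorflow import keras  # assumes TensorFlow's bundled Keras

# toy data: 1000 samples, 20 features, binary labels
x = np.random.rand(1000, 20)
y = (x[:, 0] > 0.5).astype(int)

# tiny placeholder classifier; the model builds itself on the first fit() call
model = keras.Sequential([
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['acc'])  # history keys usually follow this string ('acc'/'val_acc')

history = model.fit(x, y,
                    epochs=30,             # more than the 5 trimmed epochs
                    batch_size=32,
                    validation_split=0.2,  # needed so the val_* keys exist
                    verbose=0)

plot_history(history)
plt.show()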
Thank you for the bug report. At some point in the last year something changed -- I believe Keras no longer has `acc` in `history.history` (a possible workaround is sketched below).
❓ 🤔 I have this code as a PyPI package: https://pypi.org/project/keras-hist-graph/
And there's a repository for it too: https://github.com/whyboris/keras-hist-graph
I've not been doing much Keras lately, so I might not fix it for a while 🤷 -- if you happen to figure something out, please open a PR 🙇
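
If the breakage is just the `acc` → `accuracy` key rename, one possible workaround (an untested sketch, not part of the package) is a small lookup helper that accepts either spelling; `set_plot_history_data` and `get_max_validation_accuracy` would then call it instead of indexing `history.history` directly:

def get_metric(history, name):
    """Return history.history[name], accepting old ('acc') or new ('accuracy') key names."""
    aliases = {
        'acc': ('acc', 'accuracy'),
        'val_acc': ('val_acc', 'val_accuracy'),
    }
    for key in aliases.get(name, (name,)):
        if key in history.history:
            return history.history[key]
    raise KeyError(name)

# e.g. in set_plot_history_data:
#   train = smooth_curve(get_metric(history, 'acc'))
#   valid = smooth_curve(get_metric(history, 'val_acc'))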