
@vikramsoni2
Created June 28, 2021 04:02
Plotting train vs. validation metrics as line plots
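```python
import matplotlib.pyplot as plt
import seaborn as sns

# Assumes `history` is the History object returned by Keras `model.fit(...)` with
# validation data, and that the model was compiled with a metric named 'top10_accuracy'.
metrics_names = ['loss', 'accuracy', 'top10_accuracy']

sns.set_style('whitegrid')
plt.figure(figsize=(14, 4))
for i, name in enumerate(metrics_names):
    # One subplot per metric, laid out side by side
    ax = plt.subplot(1, len(metrics_names), i + 1)
    ax.plot(history.history[name], label="train")
    ax.plot(history.history['val_' + name], label="valid")
    ax.set_title(name)

# Share a single legend for the whole figure, centred below the subplots
lines, labels = ax.get_legend_handles_labels()
plt.gcf().legend(lines, labels, loc='lower center', ncol=2,
                 bbox_to_anchor=(0.5, -0.1), fancybox=True)
plt.tight_layout()
plt.show()
```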
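A possible reusable variant of the snippet, written as a minimal sketch rather than the gist author's code. It assumes the input is a plain dict shaped like Keras' `history.history` (`{'loss': [...], 'val_loss': [...], ...}`, one value per epoch). The function names `paired_metrics` and `plot_train_valid` are hypothetical, and the seaborn style is replaced by `ax.grid(True)` so only matplotlib is needed. If no metric names are passed, it plots every metric that has a matching `val_` series, which avoids the `KeyError` the original raises when `fit` was run without validation data.

```python
import matplotlib.pyplot as plt


def paired_metrics(history_dict):
    """Hypothetical helper: metric names that have both a train and a 'val_' series."""
    return [k for k in history_dict
            if not k.startswith('val_') and 'val_' + k in history_dict]


def plot_train_valid(history_dict, metrics_names=None, figsize=(14, 4)):
    """Hypothetical helper: plot train vs. validation curves, one subplot per metric.

    history_dict is assumed to look like Keras' `history.history`.
    """
    if metrics_names is None:
        metrics_names = paired_metrics(history_dict)
    if not metrics_names:
        raise ValueError("no metric has both a train and a 'val_' series")

    # squeeze=False keeps `axes` 2-D even when there is only one metric
    fig, axes = plt.subplots(1, len(metrics_names), figsize=figsize, squeeze=False)
    for ax, name in zip(axes[0], metrics_names):
        # Number epochs from 1 instead of matplotlib's default 0-based index
        epochs = range(1, len(history_dict[name]) + 1)
        ax.plot(epochs, history_dict[name], label='train')
        ax.plot(epochs, history_dict['val_' + name], label='valid')
        ax.set_title(name)
        ax.set_xlabel('epoch')
        ax.grid(True)

    # One shared legend below the row of subplots, as in the original snippet
    lines, labels = axes[0][0].get_legend_handles_labels()
    fig.legend(lines, labels, loc='lower center', ncol=2,
               bbox_to_anchor=(0.5, -0.1), fancybox=True)
    fig.tight_layout()
    return fig
```

Usage, assuming a Keras `history` object: `fig = plot_train_valid(history.history, ['loss', 'accuracy', 'top10_accuracy'])`. When saving, `fig.savefig('metrics.png', bbox_inches='tight')` should keep the legend, which sits below the axes, from being clipped.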