Plotting Keras model fitting statistics
Gist marhar/ca9a15d32aa6673ca6c14f729f4c59fd, created March 17, 2024 22:04
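# Assumes `model` is a compiled Keras model and that `img_iter`, `x_train`,
# `x_valid`, `y_valid`, and `batch_size` are defined elsewhere.
# import math
# history = model.fit(img_iter,
#                     epochs=10,
#                     # Run the same number of steps as fitting on arrays would:
#                     # ceil(n / batch_size) batches, the last one possibly partial.
#                     steps_per_epoch=math.ceil(len(x_train) / batch_size),
#                     validation_data=(x_valid, y_valid))

import matplotlib.pyplot as plt

# fit() returns a History object and also attaches it to the model as model.history.
metrics = model.history.history
# Number epochs from 1; assumes validation ran every epoch (the fit() default).
epochs = range(1, len(metrics['loss']) + 1)

# Plot the training loss
plt.plot(epochs, metrics['loss'], label='Training Loss')
# Plot the validation loss, which is only recorded when validation data was passed to fit()
if 'val_loss' in metrics:
    plt.plot(epochs, metrics['val_loss'], label='Validation Loss')

# Add title, axis labels, and legend
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()

# Show the plot
plt.show()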
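The same idea can be extended to every metric Keras records, not just loss. The sketch below is one possible approach, under stated assumptions: the helper names `pair_metrics` and `plot_history` are hypothetical (they are not part of Keras), and the code assumes a Keras-style `History.history` dict whose validation entries carry the `val_` prefix and were recorded every epoch. Each tracked metric (for example loss and accuracy) gets its own subplot with its training and validation curves overlaid.

```python
def pair_metrics(history_dict):
    """Group a Keras-style History.history dict into (train, val) pairs.

    Returns {metric_name: (train_values, val_values_or_None)}. Keys starting
    with 'val_' are matched to their training counterpart rather than
    listed on their own. (Hypothetical helper, not a Keras API.)
    """
    pairs = {}
    for name, values in history_dict.items():
        if name.startswith("val_"):
            continue
        val = history_dict.get("val_" + name)
        pairs[name] = (list(values), list(val) if val is not None else None)
    return pairs


def plot_history(history_dict, show=True):
    """Plot each metric in its own subplot, with epochs numbered from 1.

    Assumes validation values, if present, were recorded every epoch so the
    training and validation lists have the same length.
    """
    # Imported lazily so pair_metrics can be used without matplotlib installed.
    import matplotlib.pyplot as plt

    pairs = pair_metrics(history_dict)
    if not pairs:
        raise ValueError("history_dict contains no training metrics")

    # squeeze=False keeps `axes` 2-D even when there is only one metric.
    fig, axes = plt.subplots(1, len(pairs), figsize=(5 * len(pairs), 4), squeeze=False)
    for ax, (name, (train, val)) in zip(axes[0], pairs.items()):
        epochs = range(1, len(train) + 1)
        ax.plot(epochs, train, label="Training " + name)
        if val is not None:
            ax.plot(epochs, val, label="Validation " + name)
        ax.set_title(name)
        ax.set_xlabel("Epoch")
        ax.legend()
    fig.tight_layout()
    if show:
        plt.show()
    return fig
```

For example, after fitting, `plot_history(model.history.history)` would draw one panel per metric. Splitting the grouping step out of the plotting step keeps the `val_` matching logic testable on a plain dict.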