
@dipanjanS
Created September 20, 2019 09:13
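import os

import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix

# Assumes model2, create_cnn_architecture_model2, INPUT_SHAPE_RN,
# test_images_3ch, test_labels and class_names were defined earlier
# (they are not part of this gist).

# save model weights
os.makedirs('model_weights', exist_ok=True)
model2.save_weights(filepath='model_weights/cnn_model2_wt.h5', overwrite=True)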
# load model: rebuild the architecture, then restore the trained weights
# (lets you reuse the model later without retraining)
model2 = create_cnn_architecture_model2(input_shape=INPUT_SHAPE_RN)
model2.load_weights('model_weights/cnn_model2_wt.h5')
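# predict and evaluate on the test dataset
# scale pixels to [0, 1] (presumably matching the scaling used in training)
test_images_3ch_scaled = test_images_3ch / 255.
predictions = model2.predict(test_images_3ch_scaled)
# convert per-class probabilities into the index of the most likely class
prediction_labels = np.argmax(predictions, axis=1)
print(classification_report(test_labels, prediction_labels,
                            target_names=class_names))
# confusion matrix as a labelled table (rows = true, columns = predicted);
# a notebook renders this last expression automatically
pd.DataFrame(confusion_matrix(test_labels, prediction_labels),
             index=class_names, columns=class_names)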
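To make the evaluation step concrete, here is a minimal NumPy-only sketch of what `np.argmax`, `confusion_matrix`, and the precision/recall columns of `classification_report` compute. The helper names `confusion_matrix_np` and `per_class_precision_recall` are hypothetical, and the probabilities and labels are made-up toy values, not results from this model.

```python
import numpy as np


def confusion_matrix_np(y_true, y_pred, n_classes):
    """Count (true, predicted) pairs; rows = true class, columns = predicted
    class, the same orientation sklearn.metrics.confusion_matrix uses."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly even when an index pair repeats
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


def per_class_precision_recall(cm):
    """Precision and recall per class, derived from a confusion matrix."""
    tp = np.diag(cm).astype(float)   # correct predictions per class
    predicted = cm.sum(axis=0)       # column sums: how often each class was predicted
    actual = cm.sum(axis=1)          # row sums: support of each true class
    # leave 0.0 where a class was never predicted / never present
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    return precision, recall


# Toy softmax outputs for 4 samples and 3 classes (made-up numbers)
probs = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])
y_true = np.array([0, 1, 2, 1])
y_pred = np.argmax(probs, axis=1)    # most likely class per row

cm = confusion_matrix_np(y_true, y_pred, n_classes=3)
precision, recall = per_class_precision_recall(cm)
print(cm)
print("precision:", precision)
print("recall:   ", recall)
```

Here the last sample (true class 1) is predicted as class 0, so it lands in row 1, column 0 of the matrix, which lowers class 0's precision and class 1's recall.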