This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
EPOCHS = 10
# Scale pixel intensities to [0, 1]; presumably the raw images are 0-255
# integers (Fashion-MNIST) — confirm against the loaded dtype.
# NOTE(review): any test images evaluated later must be divided by 255. too.
train_images_scaled = train_images_gr / 255.
# validation_split=0.1 holds out 10% of the training samples for validation
# (54000 train / 6000 validation in the log below).
model.fit(train_images_scaled, train_labels, validation_split=0.1, epochs=EPOCHS)
# Output
# Train on 54000 samples, validate on 6000 samples
# Epoch 1/10
# 54000/54000 [====] - 7s 122us/sample - loss: 0.4614 - acc: 0.8323 - val_loss: 0.3462 - val_acc: 0.8725
# Epoch 2/10
# 54000/54000 [====] - 5s 86us/sample - loss: 0.3073 - acc: 0.8892 - val_loss: 0.2825 - val_acc: 0.8950
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Per-sample input shape for the CNN: 28x28 images with a single channel
# (assumes Keras' default channels-last data_format; not set explicitly here).
# NOTE(review): the trailing '| |' on these lines is a web-page copy artifact
# and must be removed (and the body re-indented) before this code can parse.
INPUT_SHAPE = (28, 28, 1) | |
# Build a small CNN with the Keras functional API.
# NOTE(review): this snippet is truncated after the second conv layer; the
# remaining layers, the Model construction and the return value are not shown.
def create_cnn_architecture_model1(input_shape): | |
# input_shape: per-sample shape, e.g. INPUT_SHAPE above.
inp = keras.layers.Input(shape=input_shape) | |
# 16 3x3 ReLU filters; stride 1 with 'same' padding keeps the spatial size.
conv1 = keras.layers.Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), | |
activation='relu', padding='same')(inp) | |
# 2x2 max-pooling halves height and width (strides default to pool_size).
pool1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(conv1) | |
# 32 3x3 ReLU filters on the pooled feature map, again size-preserving.
conv2 = keras.layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), | |
activation='relu', padding='same')(pool1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Add an explicit single-channel axis so every image becomes (28, 28, 1),
# the per-sample shape the Conv2D model takes (INPUT_SHAPE).
train_images_gr = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images_gr = test_images.reshape(test_images.shape[0], 28, 28, 1)
# Report the resulting shapes and dtypes as a sanity check.
print('\nTrain_images.shape: {}, of {}'.format(train_images_gr.shape, train_images_gr.dtype))
print('Test_images.shape: {}, of {}'.format(test_images_gr.shape, test_images_gr.dtype))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load Fashion-MNIST through Keras as ((train images, train labels),
# (test images, test labels)).
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Human-readable class names; presumably position i names integer label i.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# Report the raw shapes and dtypes before any preprocessing.
print('\nTrain_images.shape: {}, of {}'.format(train_images.shape, train_images.dtype))
print('Test_images.shape: {}, of {}'.format(test_images.shape, test_images.dtype))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# to prevent unnecessary warnings | |
import warnings | |
warnings.simplefilter(action='ignore', category=FutureWarning) | |
# TensorFlow and tf.keras | |
import tensorflow as tf | |
from tensorflow import keras | |
# Helper libraries | |
import numpy as np |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Show a 2x4 grid of selected test images with their predicted labels.
# NOTE(review): 8 axes but only 7 image indices below, so one axis stays empty.
# NOTE(review): the trailing '| |' on these lines is a web-page copy artifact.
fig, ax = plt.subplots(2, 4, figsize=(12, 6)) | |
for idx, img_idx in enumerate([15, 123, 230, 340, 450, 560, 670]): | |
# Grid row: 0 for the first four images, 1 for the rest (equals idx // 4 here).
id1 = 1 if idx > 3 else 0 | |
# Grid column.
id2 = idx % 4 | |
# Predict on a batch of one image and map the arg-max class index to a label.
# NOTE(review): one predict call per image is slow; predicting all seven
# images in a single batch would be cheaper.
predicted_label = class_label_mapping[ | |
np.argmax( | |
resnet_ft_model4.predict( | |
np.array([test_data_X[img_idx]]) | |
),axis=1
# NOTE(review): snippet truncated here; the rest of the expression and of the
# loop body is not visible.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the fully fine-tuned ResNet50 checkpoint (epoch 497 by its file name)
# and evaluate it, pinning both the load and the inference to the CPU.
with tf.device('/cpu:0'):
    resnet_ft_model4 = keras.models.load_model('resnet_finetune_full_models/resnet50_finetune_full_seti_epoch_497.h5')
    resnet_ft_model4_cm = evaluate_model_results(resnet_ft_model4, test_data_X, test_data_y,
                                                 class_label_mapping, class_labels)
# evaluate_model_results returns the confusion-matrix DataFrame. A call nested
# inside `with` is not auto-displayed by a notebook, so surface it explicitly
# as the last top-level expression (matching the other evaluation cells).
resnet_ft_model4_cm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the VGG19 model saved under vgg19_finetune_full_models and evaluate it
# on the test set. The returned confusion-matrix DataFrame is displayed when
# this call is the last expression of a notebook cell.
vgg19_full_ft_model1 = keras.models.load_model('vgg19_finetune_full_models/vgg19_finetune_full_seti.h5')
evaluate_model_results(vgg19_full_ft_model1, test_data_X, test_data_y,
                       class_label_mapping, class_labels)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the VGG19 checkpoint saved under vgg19_finetune_partial_models (epoch 99
# by its file name) and evaluate it on the test set. The returned
# confusion-matrix DataFrame is displayed when this call is the last
# expression of a notebook cell.
vgg19_partial_ft_model3 = keras.models.load_model('vgg19_finetune_partial_models/vgg19_finetune_partial_seti_epoch_99.h5')
evaluate_model_results(vgg19_partial_ft_model3, test_data_X, test_data_y,
                       class_label_mapping, class_labels)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from sklearn.metrics import classification_report, confusion_matrix | |
def evaluate_model_results(model, test_data, test_labels,
                           class_label_mapping, class_labels):
    """Evaluate a classifier: print a classification report, return a confusion matrix.

    Args:
        model: object with a Keras-style ``predict(data, verbose=...)`` that
            returns one score per class for each sample (2-D, classes on
            axis 1).
        test_data: inputs fed to ``model.predict``.
        test_labels: ground-truth labels, in the same label space as the
            values of ``class_label_mapping``.
        class_label_mapping: maps a predicted class index (arg-max position)
            to its label.
        class_labels: labels that fix the row/column order of the confusion
            matrix.

    Returns:
        pandas.DataFrame: confusion matrix with true labels as rows and
        predicted labels as columns, both indexed by ``class_labels``.
    """
    predictions = model.predict(test_data, verbose=1)
    # Arg-max over the class axis gives each sample's predicted index; map it
    # into the same label space as test_labels so the metrics compare like
    # with like.
    prediction_labels = [class_label_mapping[idx] for idx in predictions.argmax(axis=1)]
    print(classification_report(y_true=test_labels, y_pred=prediction_labels))
    matrix = confusion_matrix(y_true=test_labels, y_pred=prediction_labels,
                              labels=class_labels)
    return pd.DataFrame(matrix, index=class_labels, columns=class_labels)