# Created April 9, 2023 20:14
# Save shamilnabiyev/518d0a5929ed28d63403a710256883aa to your computer and use it in GitHub Desktop.
# Transfer learning with VGG and Keras [for image classification]
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Credits:
# Author: Gabriel Cassimiro
# Blog post: https://towardsdatascience.com/transfer-learning-with-vgg16-and-keras-50ea161580b4
# GitHub Repo: https://github.com/gabrielcassimiro17/object-detection
#
# Transfer learning on the tf_flowers dataset: a frozen, ImageNet-pretrained
# VGG16 convolutional base with a small trainable fully connected head.
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import to_categorical

IMAGE_SIZE = (150, 150)  # (height, width) every image is resized to
NUM_CLASSES = 5  # label count; shared by the one-hot encoding and the output layer


def _load_split(split):
    """Load one tf_flowers split as a pair of dense tensors.

    Each image is resized on its own *before* the split is stacked. The
    alternative, ``tfds.load(..., batch_size=-1)``, zero-pads every image up to
    the largest one in the split (TFDS uses ``padded_batch`` for features of
    unknown shape), so the later resize would squash black borders into most
    images and the padded tensor would need far more memory.

    Args:
        split: A TFDS split string, e.g. ``"train[:70%]"``.

    Returns:
        ``(images, labels)``: float32 images of shape ``(N, *IMAGE_SIZE, 3)``
        and the integer class labels, in dataset order.
    """
    ds = tfds.load("tf_flowers", split=split, as_supervised=True)  # (image, label) pairs
    ds = ds.map(lambda image, label: (tf.image.resize(image, IMAGE_SIZE), label))
    images, labels = zip(*ds)  # whole split is held in memory, as in the original script
    return tf.stack(images), tf.stack(labels)


## Loading images and labels: disjoint 70% / 30% train/test split.
# The test slice must begin where the training slice ends; "train[:30%]" would
# reuse the first 30% of the training images and inflate the test accuracy.
train_ds, train_labels = _load_split("train[:70%]")
test_ds, test_labels = _load_split("train[70%:]")

## Transforming labels to one-hot vectors for categorical_crossentropy
train_labels = to_categorical(train_labels, num_classes=NUM_CLASSES)
test_labels = to_categorical(test_labels, num_classes=NUM_CLASSES)

## Loading VGG16 without its classifier head; its weights stay frozen
base_model = VGG16(weights="imagenet", include_top=False, input_shape=train_ds.shape[1:])
base_model.trainable = False

## Preprocessing input the same way VGG16 was trained on ImageNet
train_ds = preprocess_input(train_ds)
test_ds = preprocess_input(test_ds)

## Trainable classification head
flatten_layer = layers.Flatten()
dense_layer_1 = layers.Dense(50, activation="relu")
dense_layer_2 = layers.Dense(20, activation="relu")
prediction_layer = layers.Dense(NUM_CLASSES, activation="softmax")

model = models.Sequential([
    base_model,
    flatten_layer,
    dense_layer_1,
    dense_layer_2,
    prediction_layer,
])

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# Stop when validation accuracy has not improved for 5 epochs and roll back to
# the best epoch's weights.
es = EarlyStopping(monitor="val_accuracy", mode="max", patience=5, restore_best_weights=True)
# validation_split holds out the last 20% of the training arrays (taken before shuffling).
model.fit(train_ds, train_labels, epochs=50, validation_split=0.2, batch_size=32, callbacks=[es])
model.evaluate(test_ds, test_labels)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.