Created
March 6, 2024 16:49
-
-
Save tropicbliss/5fadc680e322ce12135e4f441bbb3a3c to your computer and use it in GitHub Desktop.
Binary classification with ResNet50
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def custom_preprocess_function(x):
    """Convert one RGB image to 3-channel grayscale and apply ResNet50 preprocessing.

    Passed as ``ImageDataGenerator(preprocessing_function=...)``. Keras calls it
    on a single image after resizing/augmentation and documents that it must
    return a NumPy array of the same shape. The conversion is therefore done in
    NumPy rather than with TensorFlow ops, which would return an ``EagerTensor``
    and launch TF kernels for every image inside the data-loading loop.

    Args:
        x: One image, shape ``(height, width, 3)``, RGB channel order.

    Returns:
        NumPy array of shape ``(height, width, 3)``. It holds the luma-weighted
        grayscale value replicated across all three channels, then passed
        through ``resnet50.preprocess_input``.
    """
    image = np.asarray(x, dtype=np.float32)
    # Same luma weights tf.image.rgb_to_grayscale uses, so the result matches
    # the previous TensorFlow-based implementation.
    luma_weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
    gray = image @ luma_weights  # (H, W, 3) @ (3,) -> (H, W)
    # The ResNet50 backbone expects 3 input channels: replicate the gray one.
    gray_rgb = np.repeat(gray[..., np.newaxis], repeats=3, axis=-1)
    return preprocess_input(gray_rgb)
# Dataset locations, relative to the working directory. Each one is passed to
# ``flow_from_directory`` further down.
train_directory = "train"
validation_directory = "validation"
test_directory = "test"
# ImageNet-pretrained ResNet50 backbone without its fully connected top; the
# 224x224x3 input matches the generators' target_size.
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze every backbone layer so that only the new head layers below are
# updated by model.fit.
for layer in base_model.layers:
    layer.trainable = False

# Head: pool the backbone feature map, one hidden dense layer, then a single
# sigmoid unit for the binary output.
model = Sequential()
model.add(base_model)
model.add(GlobalAveragePooling2D())
model.add(Dense(1024, activation='relu'))
model.add(Dense(1, activation='sigmoid'))  # Binary classification

model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0001),
    metrics=['accuracy'],
)
# Training images get light augmentation (shear, zoom, horizontal flip);
# validation images only get the shared grayscale/ResNet preprocessing.
train_datagen = ImageDataGenerator(
    preprocessing_function=custom_preprocess_function,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
validation_datagen = ImageDataGenerator(
    preprocessing_function=custom_preprocess_function,
)

# Both iterators resize to the backbone's 224x224 input, yield batches of 32
# and produce one 0/1 label per image (class_mode='binary').
_flow_options = {
    'target_size': (224, 224),
    'batch_size': 32,
    'class_mode': 'binary',
}
train_generator = train_datagen.flow_from_directory(train_directory, **_flow_options)
validation_generator = validation_datagen.flow_from_directory(
    validation_directory, **_flow_options
)
# Where the best weights are written. With save_weights_only=True and this
# ".ckpt" prefix, tf.keras 2.x writes a TensorFlow-format checkpoint.
# NOTE(review): Keras 3 (tf.keras in TF >= 2.16) rejects this path and
# requires one ending in ".weights.h5" for save_weights_only — switch the
# suffix if upgrading.
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Save weights only when val_loss improves, so the checkpoint always holds the
# best epoch seen so far in this run.
cp_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    save_best_only=True,
    monitor='val_loss',
    verbose=1,
)

# Cut the learning rate 5x after 5 epochs without val_loss improvement.
# min_lr must be below the initial rate (1e-4, set in model.compile). The
# previous 1e-3 was above it, and Keras only reduces while lr > min_lr, so
# the callback could never fire.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=5,
    min_lr=1e-6,
)
# Train the classification head.
# len(generator) is ceil(samples / batch_size), so every image, including a
# final partial batch, is used once per epoch. The previous `samples // 32`
# had three problems:
#   - it hard-coded the batch size;
#   - it dropped the remainder and was 0 for sets smaller than 32 images;
#   - it scored val_loss (which drives checkpointing) on only part of the
#     validation set.
model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=[cp_callback, reduce_lr],
)
# Restore the best (lowest val_loss) weights saved by cp_callback before
# scoring the held-out test set. Without this, the last epoch's weights are
# evaluated and the checkpoint is never used.
# NOTE(review): if this run never recorded a val_loss, a checkpoint left in
# checkpoint_dir by an earlier run would be loaded — confirm the directory is
# fresh if that matters.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    model.load_weights(latest_checkpoint)

# Test images get the same preprocessing as validation, with no augmentation.
# shuffle=False keeps the evaluation order deterministic.
test_datagen = ImageDataGenerator(preprocessing_function=custom_preprocess_function)
test_generator = test_datagen.flow_from_directory(
    test_directory,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary',
    shuffle=False,
)
test_metrics = model.evaluate(test_generator, return_dict=True)
print(f"Test metrics: {test_metrics}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment