@tropicbliss · Created March 6, 2024
Binary classification with ResNet50
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
import os
def custom_preprocess_function(x):
    # Convert to grayscale, then replicate the single channel 3 times so the
    # input still matches the (H, W, 3) shape ResNet50 expects
    x = tf.image.rgb_to_grayscale(x)
    x = tf.repeat(x, repeats=3, axis=-1)
    return preprocess_input(x)
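# Optional sanity check (sketch): confirm the preprocessing preserves the
# (224, 224, 3) shape ResNet50 expects; the dummy tensor is a placeholder image.
_dummy = tf.random.uniform((224, 224, 3), maxval=255.0)
assert custom_preprocess_function(_dummy).shape == (224, 224, 3)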
# Set paths to your data
train_directory = 'train'
validation_directory = 'validation'
test_directory = "test"
# Initialize the base model, excluding the top (fully connected) layers
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# Freeze the layers of the base model to prevent them from being updated during the first training phase
for layer in base_model.layers:
    layer.trainable = False
# Create the model
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(1024, activation='relu'),
    Dense(1, activation='sigmoid')  # single sigmoid unit for binary classification
])
# Compile the model
model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])
# Prepare data augmentation configuration
train_datagen = ImageDataGenerator(
    preprocessing_function=custom_preprocess_function,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
validation_datagen = ImageDataGenerator(preprocessing_function=custom_preprocess_function)
# Prepare data generators
train_generator = train_datagen.flow_from_directory(
    train_directory,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary')
validation_generator = validation_datagen.flow_from_directory(
    validation_directory,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary')
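# flow_from_directory assigns the 0/1 labels alphabetically by subfolder name;
# printing the mapping makes the meaning of the sigmoid output explicit.
print(train_generator.class_indices)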
# Define the checkpoint directory and file name
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
os.makedirs(checkpoint_dir, exist_ok=True)  # ensure the checkpoint directory exists
# Create a callback that saves the model's weights
cp_callback = ModelCheckpoint(filepath=checkpoint_path,
                              save_weights_only=True,
                              verbose=1,
                              save_best_only=True,  # keep only the best model by validation loss
                              monitor='val_loss')
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=5,
    min_lr=1e-6)  # floor must sit below the initial LR of 1e-4, or no reduction can take effect
# Train the model
model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // 32,
    epochs=10,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // 32,
    callbacks=[cp_callback, reduce_lr])
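# Optional second phase (sketch): the base model was frozen for the first
# training phase above, so a common follow-up is to unfreeze the top of
# ResNet50 and fine-tune at a much lower learning rate. How many layers to
# unfreeze (10 here) and the extra epoch count are assumptions to tune;
# BatchNormalization layers stay frozen, as is usually recommended.
for layer in base_model.layers[-10:]:
    if not isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = True
model.compile(optimizer=Adam(learning_rate=1e-5),
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // 32,
    epochs=5,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // 32,
    callbacks=[cp_callback, reduce_lr])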
test_datagen = ImageDataGenerator(preprocessing_function=custom_preprocess_function)
test_generator = test_datagen.flow_from_directory(
    test_directory,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary')
model.evaluate(test_generator)
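# Inference sketch: reload the best checkpoint saved above and score one image.
# 'sample.jpg' is a placeholder path; 0.5 is the natural threshold for a sigmoid.
import numpy as np
from tensorflow.keras.preprocessing import image

model.load_weights(checkpoint_path)
img = image.load_img('sample.jpg', target_size=(224, 224))
x = np.expand_dims(custom_preprocess_function(image.img_to_array(img)), axis=0)
prob = float(model.predict(x)[0][0])
print(f"P(class 1) = {prob:.3f} -> {'class 1' if prob >= 0.5 else 'class 0'}")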