# Source: GitHub gist idontcalculate/41067003a3015bf0c10883d75e85f70a
# (last active May 17, 2024, 19:58).
#
# GitHub notice: this file may contain hidden or bidirectional Unicode text that
# can be interpreted or compiled differently than it appears. To review it, open
# the file in an editor that reveals hidden Unicode characters.
from pathlib import Path

import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Load the ImageNet-pretrained VGG16 convolutional base without its classifier head.
# The 256x256x3 input shape must match the target_size used by the data generators.
base_model = VGG16(input_shape=(256, 256, 3), include_top=False, weights='imagenet')
# Freeze the convolutional base so the first training phase only updates the new head.
base_model.trainable = False
# Binary classification head on top of the frozen base:
# Flatten -> 512-unit ReLU layer -> 50% dropout -> single sigmoid output.
model = models.Sequential([
    base_model,
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid'),
])
# Compile for the head-only phase; fine-tuning later recompiles with a lower rate.
model.compile(optimizer=optimizers.Adam(learning_rate=0.001),  # initial learning rate
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Training-time augmentation. VGG16's ImageNet weights expect inputs prepared by
# vgg16.preprocess_input (RGB->BGR plus per-channel ImageNet mean subtraction),
# not a 1/255 rescale, so that function is used as the preprocessing step for
# both generators. Keras applies it after the random augmentation.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Validation images get the same preprocessing but no augmentation.
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Load the training and validation data (wrist radiographs only).
train_dir = 'MURA-v1.1/train/XR_WRIST'
validation_dir = 'MURA-v1.1/valid/XR_WRIST'


def _mura_image_labels(split_dir):
    """Return a DataFrame with one ``filename``/``class`` row per labelled image.

    Assumes the standard MURA-v1.1 layout
    ``<split_dir>/<patientNNNNN>/<studyN>_<positive|negative>/<imageN>.png``
    (confirm against your copy of the dataset). ``flow_from_directory`` would
    treat each first-level sub-folder -- here a patient folder -- as its own
    class, so the label is instead read from the suffix of each image's parent
    (study) folder. Images whose parent folder ends in neither ``positive`` nor
    ``negative`` are skipped.

    Args:
        split_dir: Directory searched recursively, e.g. ``MURA-v1.1/train/XR_WRIST``.

    Returns:
        pandas.DataFrame with an absolute image path in ``filename`` and the
        string label ``'positive'`` or ``'negative'`` in ``class``.

    Raises:
        FileNotFoundError: if no labelled image is found under ``split_dir``.
    """
    image_suffixes = {'.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff'}
    rows = []
    for path in sorted(Path(split_dir).rglob('*')):
        # Skip directories, non-image files and hidden files (e.g. macOS "._" forks).
        if (not path.is_file() or path.name.startswith('.')
                or path.suffix.lower() not in image_suffixes):
            continue
        study = path.parent.name
        if study.endswith('positive'):
            rows.append((str(path.absolute()), 'positive'))
        elif study.endswith('negative'):
            rows.append((str(path.absolute()), 'negative'))
    if not rows:
        raise FileNotFoundError(f'No labelled images found under {split_dir!r}')
    return pd.DataFrame(rows, columns=['filename', 'class'])


train_generator = train_datagen.flow_from_dataframe(
    _mura_image_labels(train_dir),
    x_col='filename',
    y_col='class',
    classes=['negative', 'positive'],  # fixes the encoding: negative -> 0, positive -> 1
    target_size=(256, 256),
    batch_size=32,
    class_mode='binary',
)
validation_generator = test_datagen.flow_from_dataframe(
    _mura_image_labels(validation_dir),
    x_col='filename',
    y_col='class',
    classes=['negative', 'positive'],
    target_size=(256, 256),
    batch_size=32,
    class_mode='binary',
    shuffle=False,  # deterministic order so predictions can be matched to files
)
# Stop once validation loss has not improved for 3 consecutive epochs and, when
# stopping is triggered, restore the weights from the best epoch seen.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
# Phase 1: train only the new classification head (the VGG16 base is frozen).
history = model.fit(
    train_generator,
    epochs=10,  # upper bound for this phase; early stopping may end it sooner
    validation_data=validation_generator,
    callbacks=[early_stopping],
)
# Phase 2: fine-tune the last convolutional block of VGG16 together with the head.
# For the include_top=False base, layers[-4:] are block5_conv1..3 and block5_pool,
# i.e. three conv layers plus a weightless pooling layer; everything before them
# stays frozen.
base_model.trainable = True
for layer in base_model.layers[:-4]:
    layer.trainable = False
# Recompile so the changed trainable flags take effect, using a much lower
# learning rate so the pretrained block5 weights are only nudged.
model.compile(optimizer=optimizers.Adam(learning_rate=1e-5),  # fine-tuning learning rate
              loss='binary_crossentropy',
              metrics=['accuracy'])
# Continue training. Reusing early_stopping is safe: Keras resets its wait
# counter and best value at the start of every fit() call.
history_fine = model.fit(
    train_generator,
    epochs=10,  # additional epochs for fine-tuning
    validation_data=validation_generator,
    callbacks=[early_stopping],
)
# Save the fine-tuned model (architecture + weights); the .h5 extension selects
# the HDF5 file format.
from tensorflow.keras.models import save_model
save_model(model, "modelVGG16.h5")
# (GitHub page footer: sign up for free, or sign in, to comment on this gist.)