import os

import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

# Constants
BATCH_SIZE = 24
IMAGE_SHAPE = (224, 224)
INITIAL_LEARNING_RATE = 0.0003
# Load and preprocess dataset
train_ds, validation_ds = tfds.load("rock_paper_scissors",
                                    split=["train", "test"],
                                    as_supervised=True)

def preprocess_data(image, label):
    image = tf.image.resize(image, IMAGE_SHAPE)
    image = tf.cast(image, tf.float32) / 255.0
    # Add color augmentation
    image = tf.image.random_brightness(image, 0.2)
    image = tf.image.random_contrast(image, 0.8, 1.2)
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label
# Enhanced data augmentation
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomFlip("vertical"),
    tf.keras.layers.RandomRotation(0.4),
    tf.keras.layers.RandomZoom(0.2),
    tf.keras.layers.RandomTranslation(0.2, 0.2),
])
def prepare_dataset(dataset, train=False):
    dataset = dataset.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
    if train:
        dataset = dataset.map(
            lambda x, y: (data_augmentation(x, training=True), y),
            num_parallel_calls=tf.data.AUTOTUNE
        )
        # Repeat and shuffle only the training split: the random augmentations
        # above must be re-drawn each epoch (so the pipeline is not cached), and
        # the validation split keeps a deterministic order for the ensemble
        # evaluation later in the script.
        dataset = dataset.repeat(2)
        dataset = dataset.shuffle(10000)
    return dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

train_ds = prepare_dataset(train_ds, train=True)
validation_ds = prepare_dataset(validation_ds)
# Enhanced model architecture
def create_model():
    base_model = MobileNetV2(input_shape=(224, 224, 3),
                             include_top=False,
                             weights='imagenet')
    # Freeze the pretrained base for phase 1; it is unfrozen for fine-tuning in phase 2
    base_model.trainable = False

    inputs = tf.keras.Input(shape=(224, 224, 3))
    # Apply data augmentation only during training
    x = data_augmentation(inputs)
    # Pass through base model
    x = base_model(x)
    x = GlobalAveragePooling2D()(x)

    # Enhanced dense layers
    x = Dense(1024)(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.5)(x)

    x = Dense(512)(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.4)(x)

    x = Dense(256)(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.3)(x)

    x = Dense(128)(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

    outputs = Dense(3, activation='softmax')(x)
    return Model(inputs, outputs)
# Custom callback for model ensemble
class ModelEnsembleCallback(tf.keras.callbacks.Callback):
    def __init__(self, patience=7):
        super(ModelEnsembleCallback, self).__init__()
        self.patience = patience
        self.best_weights = []
        self.best_accs = []
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        current_acc = logs.get('val_accuracy')
        # Keep the three best sets of weights seen so far
        if len(self.best_accs) < 3 or current_acc > min(self.best_accs):
            if len(self.best_accs) >= 3:
                min_idx = np.argmin(self.best_accs)
                self.best_accs.pop(min_idx)
                self.best_weights.pop(min_idx)
            self.best_accs.append(current_acc)
            self.best_weights.append(self.model.get_weights())
            print(f'\nSaved model weights with accuracy: {current_acc:.4f}')
            self.wait = 0
        else:
            # Stop training after `patience` epochs without improvement
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True
# First phase training
print("Phase 1: Training top layers...")
model = create_model()

# Use legacy optimizer for better M1/M2 performance
model.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=INITIAL_LEARNING_RATE),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

ensemble_callback = ModelEnsembleCallback(patience=10)

history1 = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=40,
    callbacks=[
        ensemble_callback,
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_accuracy',
            factor=0.2,
            patience=5,
            min_lr=1e-6
        )
    ]
)
# Second phase - fine-tune the model
print("\nPhase 2: Fine-tuning model...")
base_model = model.layers[2]  # Get the base model (layer order: input, augmentation, MobileNetV2, ...)
base_model.trainable = True

# Freeze early layers
for layer in base_model.layers[:100]:
    layer.trainable = False

# Recompile with lower learning rate
model.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

history2 = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=30,
    callbacks=[
        ensemble_callback,
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_accuracy',
            factor=0.2,
            patience=3,
            min_lr=1e-6
        )
    ]
)
# Evaluate with ensemble
print("\nEvaluating ensemble model...")
# Collect validation images and labels in a single pass so they stay paired
val_images, val_labels = [], []
for images, labels in validation_ds:
    val_images.append(images)
    val_labels.append(labels)
val_data = tf.concat(val_images, axis=0)
val_labels = tf.concat(val_labels, axis=0).numpy()

# Average the predictions of the best saved weight sets
predictions = []
for weights in ensemble_callback.best_weights:
    model.set_weights(weights)
    pred = model.predict(val_data)
    predictions.append(pred)

ensemble_predictions = np.mean(predictions, axis=0)
ensemble_accuracy = np.mean(np.argmax(ensemble_predictions, axis=1) == val_labels)
print(f"Ensemble validation accuracy: {ensemble_accuracy:.2%}")

# Save best model
os.makedirs('models', exist_ok=True)
model.set_weights(ensemble_callback.best_weights[np.argmax(ensemble_callback.best_accs)])
model.save('models/mobilenet_best.keras')
# Plot training history
def plot_training_history(history1, history2):
    plt.figure(figsize=(15, 5))

    # Plot accuracy
    plt.subplot(1, 2, 1)
    plt.plot(history1.history['accuracy'] + history2.history['accuracy'])
    plt.plot(history1.history['val_accuracy'] + history2.history['val_accuracy'])
    plt.title('Model Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='lower right')
    plt.grid(True)

    # Plot loss
    plt.subplot(1, 2, 2)
    plt.plot(history1.history['loss'] + history2.history['loss'])
    plt.plot(history1.history['val_loss'] + history2.history['val_loss'])
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper right')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('training_history_final.png')
    plt.close()

plot_training_history(history1, history2)

# Print all best accuracies achieved
print("\nTop validation accuracies achieved:")
for i, acc in enumerate(sorted(ensemble_callback.best_accs, reverse=True)):
    print(f"Model {i+1}: {acc:.2%}")