# Originally published as a GitHub Gist by @schappim
# (gist f389e5ce90af8e627824e4d7ea05beb9), created November 9, 2024 01:10.
#
# Transfer-learning script: trains a MobileNetV2-based rock/paper/scissors
# classifier in two phases (initial training, then fine-tuning at a lower
# learning rate with the first 100 backbone layers frozen), evaluates an
# ensemble of the best epoch snapshots, saves the best single model and plots
# the training curves.
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
# Constants
BATCH_SIZE = 24
IMAGE_SHAPE = (224, 224)  # (height, width) every image is resized to
INITIAL_LEARNING_RATE = 3e-4  # Adam learning rate for the first training phase

# Load the TFDS rock/paper/scissors data as (image, label) pairs; the "test"
# split serves as the validation set for the rest of this script.
train_ds, validation_ds = tfds.load(
    "rock_paper_scissors",
    split=["train", "test"],
    as_supervised=True,
)
def preprocess_data(image, label):
    """Resize, rescale and colour-jitter one training example.

    The image is resized to ``IMAGE_SHAPE`` and scaled to [0, 1]. Its
    brightness is then shifted by a random delta in [-0.2, 0.2] and its
    contrast by a random factor in [0.8, 1.2]. Finally it is clipped back
    into [0, 1].

    Args:
        image: one image tensor (as yielded by ``as_supervised=True``).
        label: its label, passed through unchanged.

    Returns:
        ``(image, label)`` with ``image`` as float32 in [0, 1].
    """
    resized = tf.image.resize(image, IMAGE_SHAPE)
    scaled = tf.cast(resized, tf.float32) / 255.0
    # Random colour augmentation (re-sampled on every call).
    jittered = tf.image.random_contrast(
        tf.image.random_brightness(scaled, 0.2), 0.8, 1.2
    )
    return tf.clip_by_value(jittered, 0.0, 1.0), label
# Random geometric augmentation. Keras random preprocessing layers transform
# their input only in training mode and pass it through unchanged otherwise.
data_augmentation = tf.keras.Sequential(
    layers=[
        tf.keras.layers.RandomFlip(mode="horizontal"),
        tf.keras.layers.RandomFlip(mode="vertical"),
        # factor is a fraction of 2*pi: rotations of up to about +/-144 degrees.
        tf.keras.layers.RandomRotation(factor=0.4),
        # Zoom in/out by up to 20%; width follows height (aspect kept).
        tf.keras.layers.RandomZoom(height_factor=0.2),
        # Shift by up to 20% of the height and of the width.
        tf.keras.layers.RandomTranslation(height_factor=0.2, width_factor=0.2),
    ]
)
def _resize_and_scale(image, label):
    """Resize ``image`` to ``IMAGE_SHAPE`` and rescale its pixels to [0, 1].

    Deterministic counterpart of ``preprocess_data`` (no colour jitter), used
    for evaluation data.
    """
    image = tf.image.resize(image, IMAGE_SHAPE)
    image = tf.cast(image, tf.float32) / 255.0
    return image, label


def prepare_dataset(dataset, train=False):
    """Turn an unbatched ``(image, label)`` dataset into batched model input.

    Args:
        dataset: ``tf.data.Dataset`` of ``(image, label)`` pairs, e.g. as
            returned by ``tfds.load(..., as_supervised=True)``.
        train: if True, the dataset is shuffled and repeated twice, so one
            Keras epoch covers the split twice. Each example is then passed
            through ``preprocess_data``, which applies random colour jitter
            that is re-sampled every epoch. If False, examples are only
            resized and rescaled, with no randomness and a stable order.

    Returns:
        A batched, prefetched dataset of ``(float32 image in [0, 1], label)``.
    """
    if not train:
        # Evaluation data must be deterministic: no random colour jitter and no
        # reshuffling, so every pass yields the same examples in the same order.
        return (
            dataset.map(_resize_and_scale, num_parallel_calls=tf.data.AUTOTUNE)
            .cache()
            .batch(BATCH_SIZE)
            .prefetch(tf.data.AUTOTUNE)
        )

    # Cache the *source* images, before any random op. Everything upstream of
    # cache() is computed once and replayed, so caching after augmentation
    # froze the first epoch's random transforms for the entire run.
    # Geometric augmentation is deliberately not applied here: create_model()
    # already runs `data_augmentation` inside the model, where it is active
    # during fit(). Applying it here as well stacked the transforms twice.
    return (
        dataset.cache()
        .shuffle(10000)
        .repeat(2)
        .map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )
# Replace the raw TFDS splits with batched pipelines (training one first).
train_ds, validation_ds = (
    prepare_dataset(train_ds, train=True),
    prepare_dataset(validation_ds, train=False),
)
# Enhanced model architecture
def create_model():
    """Build the MobileNetV2 transfer-learning classifier (not compiled).

    Layer order is Input -> ``data_augmentation`` -> MobileNetV2 backbone ->
    classification head, so the backbone is ``model.layers[2]``. The
    fine-tuning phase looks it up by that position; keep it when editing.

    The ImageNet backbone is returned frozen, so the first training phase fits
    only the new head. The fine-tuning phase unfreezes it later.

    Returns:
        A ``tf.keras.Model`` mapping ``IMAGE_SHAPE + (3,)`` float images to
        softmax probabilities over 3 classes.
    """
    input_shape = IMAGE_SHAPE + (3,)
    base_model = MobileNetV2(input_shape=input_shape,
                             include_top=False,
                             weights='imagenet')
    # Phase 1 is meant to train only the new head ("Training top layers").
    # Left trainable, the randomly initialised head's large early gradients
    # also update, and can degrade, the pretrained backbone.
    base_model.trainable = False
    # NOTE(review): Keras documents MobileNetV2 as expecting inputs in [-1, 1]
    # (mobilenet_v2.preprocess_input), whereas the data pipeline feeds [0, 1].
    # Adding a rescaling layer would shift the backbone's index in
    # model.layers; confirm the intended input range before changing it.

    inputs = tf.keras.Input(shape=input_shape)
    # Random augmentation layers act only in training mode (e.g. during fit).
    x = data_augmentation(inputs)
    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, even after they are unfrozen for fine-tuning. Small
    # batches then cannot overwrite their ImageNet moving statistics.
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    # Head: repeated Dense -> BatchNorm -> LeakyReLU (-> Dropout) blocks,
    # narrowing from 1024 to 128 units; the last block has no dropout.
    for units, dropout_rate in ((1024, 0.5), (512, 0.4), (256, 0.3), (128, None)):
        x = Dense(units)(x)
        x = BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
        if dropout_rate is not None:
            x = Dropout(dropout_rate)(x)
    outputs = Dense(3, activation='softmax')(x)
    return Model(inputs, outputs)
# Custom callback for model ensemble
class ModelEnsembleCallback(tf.keras.callbacks.Callback):
    """Keep weight snapshots of the best epochs and stop when none improves.

    After every epoch, ``val_accuracy`` is read from the epoch logs. The
    model's current weights are stored when fewer than ``max_models``
    snapshots are held, or when the new accuracy beats the worst one held
    (that worst snapshot is then evicted). Storing a snapshot resets the
    patience counter. Otherwise the counter grows, and training is stopped
    once it reaches ``patience``.

    Snapshots persist across successive ``fit`` calls on the same instance, so
    one callback can collect an ensemble spanning several training phases; the
    patience counter, however, restarts at the beginning of each ``fit``.

    Args:
        patience: consecutive epochs without a new snapshot after which
            ``model.stop_training`` is set.
        max_models: maximum number of weight snapshots to keep.
    """

    def __init__(self, patience=7, max_models=3):
        super().__init__()
        self.patience = patience
        self.max_models = max_models
        self.best_weights = []  # model.get_weights() snapshots
        self.best_accs = []  # val_accuracy of each snapshot, parallel to best_weights
        self.wait = 0  # epochs since the last snapshot in the current fit()

    def on_train_begin(self, logs=None):
        # Give every fit() call a fresh patience budget. Otherwise an early
        # stop in one phase leaves `wait` at `patience`, and the next phase
        # stops after its first non-improving epoch. Snapshots are kept.
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        current_acc = (logs or {}).get('val_accuracy')
        if current_acc is None:
            # Nothing to rank (e.g. fit() was called without validation data).
            print("\nModelEnsembleCallback: 'val_accuracy' not in logs; epoch skipped.")
            return
        if len(self.best_accs) < self.max_models or current_acc > min(self.best_accs):
            if len(self.best_accs) >= self.max_models:
                worst = int(np.argmin(self.best_accs))
                del self.best_accs[worst]
                del self.best_weights[worst]
            self.best_accs.append(current_acc)
            self.best_weights.append(self.model.get_weights())
            print(f'\nSaved model weights with accuracy: {current_acc:.4f}')
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True
# ---- Phase 1: initial training -------------------------------------------
print("Phase 1: Training top layers...")
model = create_model()

# Use legacy optimizer for better M1/M2 performance.
# (tf.keras.optimizers.legacy is not available under Keras 3.)
phase1_optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=INITIAL_LEARNING_RATE)
model.compile(optimizer=phase1_optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# One ensemble callback is shared by both phases, so its snapshots span them.
ensemble_callback = ModelEnsembleCallback(patience=10)
# Cut the learning rate 5x after 5 epochs without val_accuracy improvement.
phase1_lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy',
                                                          factor=0.2,
                                                          patience=5,
                                                          min_lr=1e-6)

history1 = model.fit(train_ds,
                     validation_data=validation_ds,
                     epochs=40,
                     callbacks=[ensemble_callback, phase1_lr_schedule])
# Second phase - fine-tune the model
print("\nPhase 2: Fine-tuning model...")
# create_model() builds Input -> data_augmentation -> MobileNetV2 -> head, so
# index 2 is the pretrained backbone.
base_model = model.layers[2]  # Get the base model
base_model.trainable = True
# Fine-tune only from backbone layer 100 upwards. Every BatchNormalization
# layer also stays frozen, so the small fine-tuning batches do not overwrite
# the ImageNet moving statistics.
for index, layer in enumerate(base_model.layers):
    layer.trainable = index >= 100 and not isinstance(layer, BatchNormalization)

# Recompile (required for trainable changes to take effect) with a lower LR.
model.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
# The callback is reused from phase 1 so its snapshots carry over. Its
# patience counter must not: if phase 1 early-stopped, `wait` is already at
# the limit and phase 2 would stop after one non-improving epoch.
ensemble_callback.wait = 0
history2 = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=30,
    callbacks=[
        ensemble_callback,
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_accuracy',
            factor=0.2,
            patience=3,
            min_lr=1e-6
        )
    ]
)
# Evaluate with ensemble
print("\nEvaluating ensemble model...")
# Iterate validation_ds exactly ONCE and take images and labels from the same
# pass. Two separate passes over a pipeline that reshuffles on every
# iteration yield images and labels in different orders, which makes the
# accuracy below meaningless.
val_batches = list(validation_ds)
val_data = tf.concat([x for x, _ in val_batches], axis=0)
val_labels = tf.concat([y for _, y in val_batches], axis=0).numpy()

# Soft voting: average the softmax outputs of every stored snapshot.
predictions = []
for weights in ensemble_callback.best_weights:
    model.set_weights(weights)
    predictions.append(model.predict(val_data))
ensemble_predictions = np.mean(predictions, axis=0)
# NOTE: the snapshots were selected by val_accuracy on this same set, so this
# figure is an optimistic estimate, not a held-out test score.
ensemble_accuracy = np.mean(np.argmax(ensemble_predictions, axis=1) == val_labels)
print(f"Ensemble validation accuracy: {ensemble_accuracy:.2%}")
# Save best model: restore the single highest-scoring snapshot (not the
# ensemble average) and write it out.
model.set_weights(ensemble_callback.best_weights[np.argmax(ensemble_callback.best_accs)])
# model.save() writes one .keras archive and fails if models/ does not exist.
tf.io.gfile.makedirs('models')
model.save('models/mobilenet_best.keras')
# Plot training history
def plot_training_history(history1, history2):
    """Save accuracy and loss curves for both training phases to a PNG.

    Per-epoch values from ``history2`` are appended after those from
    ``history1``, so the two phases form one continuous curve per metric. The
    figure is saved as ``training_history_final.png`` and then closed.

    Args:
        history1: ``History`` object returned by the first ``model.fit``.
        history2: ``History`` object returned by the second ``model.fit``.
    """
    def both_phases(key):
        return history1.history[key] + history2.history[key]

    # (train key, validation key, title, y-label, legend location) per panel.
    panels = (
        ('accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy', 'lower right'),
        ('loss', 'val_loss', 'Model Loss', 'Loss', 'upper right'),
    )
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    for ax, (train_key, val_key, title, ylabel, legend_loc) in zip(axes, panels):
        ax.plot(both_phases(train_key))
        ax.plot(both_phases(val_key))
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Epoch')
        ax.legend(['Train', 'Validation'], loc=legend_loc)
        ax.grid(True)
    fig.tight_layout()
    fig.savefig('training_history_final.png')
    plt.close(fig)
plot_training_history(history1, history2)

# Report the validation accuracy of every stored snapshot, best first.
print("\nTop validation accuracies achieved:")
ranked_accs = sorted(ensemble_callback.best_accs, reverse=True)
for rank, acc in enumerate(ranked_accs, start=1):
    print(f"Model {rank}: {acc:.2%}")
# End of gist.