Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save udaylunawat/f17d57a674607c5d2f0382b4de939c3a to your computer and use it in GitHub Desktop.
Save udaylunawat/f17d57a674607c5d2f0382b4de939c3a to your computer and use it in GitHub Desktop.
Multi-class image classification using Keras ImageDataGenerator
import os
# Set before TensorFlow is imported below; presumably so TF's C++ log level
# ("3") takes effect from import time — confirm before reordering imports.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import gc
import subprocess
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
# NOTE(review): gc, subprocess, random, pandas, seaborn, pyplot and
# classification_report are not referenced by the code visible in this file;
# presumably used by helpers not shown here (evaluate, seed_everything).
import tensorflow as tf
import tensorflow_addons as tfa
# Speed improvement: switch Keras to the 'mixed_float16' dtype policy for the
# whole process, at import time, before any model is built.
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
def get_earlystopper(args):
    """Return an EarlyStopping callback configured from ``args.callback_config``.

    Watches ``val_loss`` (mode ``'auto'``), waits ``early_patience`` epochs
    without improvement, runs silently (``verbose=0``) and asks Keras to
    restore the best weights it saw.
    """
    callback_cfg = args.callback_config
    stopper_kwargs = {
        "monitor": "val_loss",
        "patience": callback_cfg.early_patience,
        "verbose": 0,
        "mode": "auto",
        "restore_best_weights": True,
    }
    return tf.keras.callbacks.EarlyStopping(**stopper_kwargs)
def get_reduce_lr_on_plateau(args):
    """Return a ReduceLROnPlateau callback configured from ``args.callback_config``.

    Watches ``val_loss``; the reduction factor, patience and learning-rate
    floor come from ``rlrp_factor``, ``rlrp_patience`` and ``min_lr``.
    """
    callback_cfg = args.callback_config
    factor = callback_cfg.rlrp_factor
    patience = callback_cfg.rlrp_patience
    floor_lr = callback_cfg.min_lr
    return tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=factor, patience=patience, min_lr=floor_lr
    )
def get_optimizer(config):
    """Build the Keras optimizer named by ``config.train_config.optimizer``.

    Supported names are ``"adam"``, ``"rms"``, ``"sgd"`` and ``"adamax"``;
    each optimizer is created with ``config.train_config.lr`` as its learning
    rate.

    Raises:
        ValueError: if the optimizer name is not one of the supported names.
            (Previously an unknown name crashed with UnboundLocalError.)
    """
    supported = ("adam", "rms", "sgd", "adamax")
    name = config.train_config.optimizer
    if name not in supported:
        raise ValueError(
            f"Unsupported optimizer {name!r}; expected one of {supported}"
        )

    # Local import: ``optimizers`` was referenced but never imported at
    # module level. Done after validation so bad names fail fast.
    from tensorflow.keras import optimizers

    lr = config.train_config.lr
    if name == "adam":
        return optimizers.Adam(learning_rate=lr)
    if name == "rms":
        # ``decay=0.0`` was dropped: it was a no-op on legacy optimizers, and
        # the TF >= 2.11 optimizers reject the ``decay`` argument.
        return optimizers.RMSprop(learning_rate=lr, rho=0.9, epsilon=1e-08)
    if name == "sgd":
        return optimizers.SGD(learning_rate=lr)
    return optimizers.Adamax(learning_rate=lr)
def get_model_weights_gen(train_generator):
    """Return balanced per-class weights for ``model.fit(class_weight=...)``.

    Each label ``c`` gets ``n_samples / (n_classes * count(c))``. This is the
    same formula as scikit-learn's ``compute_class_weight("balanced")``, so
    rarer classes receive larger weights.

    Args:
        train_generator: object whose ``classes`` attribute holds the integer
            label of every training sample (e.g. a Keras directory iterator).

    Returns:
        dict mapping each label that occurs in ``classes`` to its weight as a
        Python float. Keys are the actual label values, not positions in the
        sorted label list, so weights stay aligned even when some class index
        has no samples. An empty ``classes`` yields an empty dict.
    """
    # Computed with numpy directly: the sklearn ``class_weight`` module the
    # old code called was never imported.
    labels = np.asarray(train_generator.classes)
    present, counts = np.unique(labels, return_counts=True)
    weights = labels.size / (present.size * counts)
    return {int(label): float(weight) for label, weight in zip(present, weights)}
def get_generators(config, data_dir="data/4_tfds_dataset"):
    """Create the train / validation / test directory iterators.

    Images are read with ``flow_from_directory`` from ``<data_dir>/train``,
    ``<data_dir>/val`` and ``<data_dir>/test`` as RGB. They are resized to
    the configured size, rescaled by 1/255 and labelled with
    ``class_mode="categorical"``. Only the training iterator is augmented,
    and only when ``config.train_config.use_augmentations`` is truthy.

    Args:
        config: experiment config. Reads ``dataset_config.image_width``,
            the optional ``dataset_config.image_height`` (falls back to the
            width when missing or empty), ``dataset_config.batch_size``,
            ``train_config.use_augmentations`` and ``seed``.
        data_dir: root folder holding the ``train``/``val``/``test`` splits.

    Returns:
        ``(train_dataset, val_dataset, test_generator)``. The train and
        validation iterators shuffle; the test iterator does not, so its
        sample order is stable across passes.
    """
    # Local import: ``ImageDataGenerator`` was referenced but never imported.
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    dataset_cfg = config.dataset_config
    width = dataset_cfg.image_width
    # ``target_size`` is (height, width). The old code used the width for
    # both, so keep that as the fallback when no height is configured.
    height = getattr(dataset_cfg, "image_height", None) or width
    image_size = (height, width)
    batch_size = dataset_cfg.batch_size

    if config.train_config.use_augmentations:
        print("\n\nAugmentation is True! rescale=1./255")
        train_datagen = ImageDataGenerator(
            horizontal_flip=True,
            vertical_flip=True,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            zoom_range=[0.5, 1.5],
            rescale=1.0 / 255,
        )
    else:
        # Plain ``else``: the old ``elif not ...`` / ``else`` split had an
        # unreachable branch that would have left ``train_datagen`` unbound.
        print("No Augmentation!")
        train_datagen = ImageDataGenerator(rescale=1.0 / 255)

    train_dataset = train_datagen.flow_from_directory(
        os.path.join(data_dir, "train"),
        target_size=image_size,
        batch_size=batch_size,
        shuffle=True,
        color_mode="rgb",
        class_mode="categorical",
    )

    # Validation and test images are only rescaled, never augmented.
    test_datagen = ImageDataGenerator(rescale=1.0 / 255)
    val_dataset = test_datagen.flow_from_directory(
        os.path.join(data_dir, "val"),
        shuffle=True,
        color_mode="rgb",
        target_size=image_size,
        batch_size=batch_size,
        class_mode="categorical",
    )
    test_generator = test_datagen.flow_from_directory(
        os.path.join(data_dir, "test"),
        batch_size=batch_size,
        seed=config.seed,
        color_mode="rgb",
        shuffle=False,
        class_mode="categorical",
        target_size=image_size,
    )
    return train_dataset, val_dataset, test_generator
def train(config, train_dataset, val_dataset, labels):
    """Build, compile and fit the model, returning ``(model, history)``.

    Args:
        config: experiment config.
            ``train_config`` supplies the optimizer settings, loss, base
            metrics and epoch count. ``dataset_config.num_classes`` sizes the
            F1 metric. ``callback_config`` configures early stopping and LR
            reduction.
        train_dataset: training iterator. Its ``classes`` attribute is also
            used to derive balanced class weights.
        val_dataset: validation data passed to ``model.fit``.
        labels: class names. Not used in this function; kept so the call
            signature stays unchanged.

    Returns:
        Tuple of the fitted model and the object returned by ``model.fit``.

    Note:
        ``get_model`` and ``wandbcallback`` are not defined in this file and
        must be available at module scope when ``train`` runs.
    """
    tf.keras.backend.clear_session()
    model = get_model(config)
    model.summary()

    # Copy the configured metrics instead of appending to the config's list.
    # The old in-place append mutated the shared config and added one more
    # F1Score on every call to train().
    metrics = list(config.train_config.metrics)
    metrics.append(
        tfa.metrics.F1Score(
            num_classes=config.dataset_config.num_classes,
            average="macro",
            threshold=0.5,
        )
    )

    class_weights = get_model_weights_gen(train_dataset)

    # Loss scaling to pair with the module-level 'mixed_float16' policy.
    optimizer = mixed_precision.LossScaleOptimizer(get_optimizer(config))

    model.compile(
        optimizer=optimizer,
        loss=config.train_config.loss,
        metrics=metrics,
    )

    earlystopper = get_earlystopper(config)
    reduce_lr = get_reduce_lr_on_plateau(config)
    callbacks = [wandbcallback, earlystopper, reduce_lr]

    history = model.fit(
        train_dataset,
        epochs=config.train_config.epochs,
        validation_data=val_dataset,
        callbacks=callbacks,
        class_weight=class_weights,
        # NOTE(review): ``workers`` only matters for generator/Sequence
        # input; confirm -1 is the intended value.
        workers=-1,
        verbose=1,
    )
    return model, history
def main(_=None):
    """Seed, build the data iterators, train, then evaluate on the test split.

    The positional argument is ignored. It defaults to ``None`` so the
    function works both when called as ``main()`` (as the ``__main__`` guard
    does) and when a runner passes an argv list.

    Relies on ``config``, ``seed_everything`` and ``evaluate``, which are not
    defined in this file and must exist at module scope.
    """
    seed_everything(config.seed)
    train_dataset, val_dataset, test_dataset = get_generators(config)
    # Class names ordered by the index the directory iterator assigned them.
    # This replaces hard-coded placeholders that only fit a 3-class layout
    # and could mislabel classes.
    class_indices = train_dataset.class_indices
    labels = sorted(class_indices, key=class_indices.get)
    model, history = train(config, train_dataset, val_dataset, labels)
    evaluate(config, model, history, test_dataset, labels)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment