Skip to content

Instantly share code, notes, and snippets.

@udaylunawat
Last active September 8, 2022 15:37
Show Gist options
  • Save udaylunawat/4463eb8c40d15a67f657dd387f776fcc to your computer and use it in GitHub Desktop.
Save udaylunawat/4463eb8c40d15a67f657dd387f776fcc to your computer and use it in GitHub Desktop.
Multi-class image classification with Keras using tf.keras.utils.image_dataset_from_directory
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import gc
import subprocess
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
import tensorflow as tf
import tensorflow_addons as tfa
# Speed improvement: enable Keras mixed precision ('mixed_float16': layers
# compute in float16, variables are kept in float32). This call sets the
# *global* policy at import time, so it applies to every layer/model built
# after this module is imported.
# NOTE(review): the Keras mixed-precision guide recommends keeping the model's
# final (softmax) output in float32 under this policy — confirm get_model does.
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
def get_earlystopper(args):
    """Create the early-stopping callback used during training.

    Training is stopped once ``val_loss`` has not improved for
    ``args.callback_config.early_patience`` epochs, and the model is rolled
    back to the weights of its best epoch.

    Args:
        args: Run configuration exposing a ``callback_config`` section.

    Returns:
        A configured ``tf.keras.callbacks.EarlyStopping`` instance.
    """
    callback_cfg = args.callback_config
    return tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        mode='auto',
        patience=callback_cfg.early_patience,
        restore_best_weights=True,
        verbose=0,
    )
def get_reduce_lr_on_plateau(args):
    """Create the learning-rate-on-plateau callback used during training.

    All hyper-parameters come from ``args.callback_config``. The learning rate
    is multiplied by ``rlrp_factor`` after ``rlrp_patience`` epochs without a
    ``val_loss`` improvement, and is never reduced below ``min_lr``.

    Args:
        args: Run configuration exposing a ``callback_config`` section.

    Returns:
        A configured ``tf.keras.callbacks.ReduceLROnPlateau`` instance.
    """
    cb_cfg = args.callback_config
    schedule_kwargs = dict(
        factor=cb_cfg.rlrp_factor,
        patience=cb_cfg.rlrp_patience,
        min_lr=cb_cfg.min_lr,
    )
    return tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', **schedule_kwargs)
# Maps ``config.train_config.optimizer`` to the ``tf.keras.optimizers`` class
# name plus any non-default constructor kwargs for that optimizer.
_OPTIMIZER_SPECS = {
    "adam": ("Adam", {}),
    # ``decay=0.0`` used to be passed here. It is the default on older Keras
    # optimizers and is rejected by the TF >= 2.11 optimizers, so it is omitted.
    "rms": ("RMSprop", {"rho": 0.9, "epsilon": 1e-08}),
    "sgd": ("SGD", {}),
    "adamax": ("Adamax", {}),
}


def get_optimizer(config):
    """Build the Keras optimizer selected by ``config.train_config``.

    Args:
        config: Run configuration. Reads ``train_config.optimizer`` (one of
            ``"adam"``, ``"rms"``, ``"sgd"``, ``"adamax"``) and
            ``train_config.lr`` (the learning rate).

    Returns:
        A freshly constructed ``tf.keras.optimizers`` instance.

    Raises:
        ValueError: If ``train_config.optimizer`` is not a supported name.
            Previously an unknown name crashed with UnboundLocalError.
    """
    train_cfg = config.train_config
    try:
        class_name, extra_kwargs = _OPTIMIZER_SPECS[train_cfg.optimizer]
    except KeyError:
        raise ValueError(
            f"Unknown optimizer {train_cfg.optimizer!r}; "
            f"expected one of {sorted(_OPTIMIZER_SPECS)}"
        ) from None

    # Imported locally: ``optimizers`` was referenced but never imported in
    # this module. Importing after validation keeps the error path TF-free.
    from tensorflow.keras import optimizers

    optimizer_cls = getattr(optimizers, class_name)
    return optimizer_cls(learning_rate=train_cfg.lr, **extra_kwargs)
def get_model_weights_ds(train_ds):
    """Compute "balanced" per-class weights for ``model.fit(class_weight=...)``.

    The previous implementation passed ``train_ds.class_names`` (one entry per
    class) as the label vector. That makes every class appear exactly once, so
    every weight came out as 1.0 and class weighting did nothing. This version
    counts the real labels yielded by the dataset. It uses the same formula as
    scikit-learn's ``compute_class_weight("balanced")``:
    ``n_samples / (n_classes * count[c])``.

    Note: this iterates the whole dataset once, which also decodes every image.

    Args:
        train_ds: Batched dataset yielding ``(images, labels)`` pairs, with a
            ``class_names`` attribute, as returned by
            ``tf.keras.utils.image_dataset_from_directory``. Labels may be
            one-hot (``label_mode='categorical'``) or integer class indices.

    Returns:
        Dict mapping class index (position in ``train_ds.class_names``) to weight.

    Raises:
        ValueError: If the dataset is empty or some class has no samples. A
            balanced weight is undefined for an empty class.
    """
    num_classes = len(train_ds.class_names)

    label_batches = [np.asarray(labels) for _, labels in train_ds]
    if not label_batches:
        raise ValueError("train_ds yielded no batches; cannot compute class weights")
    y = np.concatenate(label_batches, axis=0)

    # One-hot / categorical labels -> class indices. A single trailing column
    # (binary labels shaped (batch, 1)) is already an index, so it is only flattened.
    if y.ndim > 1 and y.shape[-1] > 1:
        y = y.argmax(axis=-1)
    y = y.reshape(-1).astype(np.int64)

    counts = np.bincount(y, minlength=num_classes)
    empty = np.flatnonzero(counts == 0)
    if empty.size:
        missing = [train_ds.class_names[i] for i in empty if i < num_classes]
        raise ValueError(f"No training samples for classes {missing}")

    weights = len(y) / (num_classes * counts.astype(np.float64))
    return {class_idx: float(weight) for class_idx, weight in enumerate(weights)}
def configure_for_performance(ds, config):
    """Cache, shuffle, (re-)batch and prefetch a dataset for training.

    ``ds`` is expected to come from ``tf.keras.utils.image_dataset_from_directory``
    and is therefore already batched with ``config.dataset_config.batch_size``.
    ``cache()`` replays the elements produced on the first pass. Shuffling the
    batched stream would therefore only reorder whole, frozen batches, and the
    1000-slot buffer would count batches rather than images. The dataset is
    unbatched so the buffer mixes individual samples, then batched again.

    Args:
        ds: Batched ``tf.data.Dataset`` of ``(images, labels)``.
        config: Run configuration; reads ``dataset_config.batch_size``.

    Returns:
        The transformed ``tf.data.Dataset``.
    """
    batch_size = config.dataset_config.batch_size

    ds = ds.cache()
    if batch_size is not None:
        # Shuffle individual samples, not whole cached batches.
        ds = ds.unbatch()
        ds = ds.shuffle(buffer_size=1000)
        ds = ds.batch(batch_size)
    else:
        # Unbatched input: elements are already single samples.
        ds = ds.shuffle(buffer_size=1000)
    ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    return ds
def _image_split_from_dir(config, directory, image_size, shuffle):
    """Load one split directory as a batched, one-hot-labelled RGB image dataset."""
    return tf.keras.utils.image_dataset_from_directory(
        directory,
        labels='inferred',
        label_mode='categorical',
        color_mode='rgb',
        batch_size=config.dataset_config.batch_size,
        image_size=image_size,
        shuffle=shuffle,
        seed=config.seed,
    )


def get_tfds_from_dir(config, data_dir="data/4_tfds_dataset"):
    """Load the train/val/test image datasets from ``data_dir``.

    Each split lives in ``<data_dir>/<split>/<class_name>/...``. Labels are
    inferred from the class sub-directories and returned one-hot encoded
    (``label_mode='categorical'``). Train and val are shuffled with
    ``config.seed``. Test is not shuffled, so it keeps a fixed file order.

    Args:
        config: Run configuration; reads ``dataset_config.image_width``,
            ``dataset_config.batch_size`` and ``seed``.
        data_dir: Root directory containing ``train``, ``val`` and ``test``.
            Defaults to the previously hard-coded location.

    Returns:
        Tuple ``(train_ds, val_ds, test_ds)`` of batched ``tf.data.Dataset``s.
    """
    width = config.dataset_config.image_width
    # NOTE(review): the height is also taken from image_width, so images are
    # resized to squares. Confirm this is intended (vs. an image_height setting).
    image_size = (width, width)

    train_ds = _image_split_from_dir(
        config, os.path.join(data_dir, "train"), image_size, shuffle=True)
    val_ds = _image_split_from_dir(
        config, os.path.join(data_dir, "val"), image_size, shuffle=True)
    test_ds = _image_split_from_dir(
        config, os.path.join(data_dir, "test"), image_size, shuffle=False)
    return train_ds, val_ds, test_ds
def train(config, train_dataset, val_dataset, labels):
    """Build, compile and fit the model on the given datasets.

    Args:
        config: Run configuration. Reads the ``train_config``,
            ``dataset_config`` and ``callback_config`` sections.
        train_dataset: Batched training dataset from ``get_tfds_from_dir``.
            It is also used to derive class weights.
        val_dataset: Batched validation dataset.
        labels: Class names. Not used inside this function; kept for
            interface compatibility.

    Returns:
        ``(model, history)``: the fitted Keras model and the ``History``
        object returned by ``model.fit``.
    """
    tf.keras.backend.clear_session()
    model = get_model(config)
    model.summary()

    # Build a fresh metric list instead of appending to
    # config.train_config.metrics. Mutating the shared config stacked another
    # F1Score on every call and reused metric objects across models.
    # NOTE(review): for single-label one-hot targets, tfa's threshold=None
    # (argmax) may be the intended behaviour. With threshold=0.5, a sample whose
    # top probability is below 0.5 counts as predicting no class.
    metrics = list(config.train_config.metrics)
    metrics.append(
        tfa.metrics.F1Score(num_classes=config.dataset_config.num_classes,
                            average="macro",
                            threshold=0.5))

    class_weights = get_model_weights_ds(train_dataset)

    optimizer = get_optimizer(config)
    # Explicit loss scaling to pair with the global 'mixed_float16' policy.
    optimizer = mixed_precision.LossScaleOptimizer(optimizer)

    model.compile(
        optimizer=optimizer,
        loss=config.train_config.loss,
        metrics=metrics,
    )

    # NOTE(review): ``wandbcallback`` is not defined in this part of the file;
    # presumably a module-level W&B callback — confirm it exists.
    callbacks = [
        wandbcallback,
        get_earlystopper(config),
        get_reduce_lr_on_plateau(config),
    ]

    train_dataset = configure_for_performance(train_dataset, config)
    val_dataset = configure_for_performance(val_dataset, config)

    # ``workers`` is omitted: Keras only uses it for generator/Sequence inputs,
    # so it had no effect on these tf.data datasets.
    history = model.fit(
        train_dataset,
        epochs=config.train_config.epochs,
        validation_data=val_dataset,
        callbacks=callbacks,
        class_weight=class_weights,
        verbose=1,
    )
    return model, history
def main(_=None):
    """Entry point: seed everything, load the splits, train, then evaluate.

    The ignored positional argument presumably exists for launchers such as
    absl's ``app.run(main)``, which pass argv. It now defaults to ``None``, so
    the bare ``main()`` call below works. Previously that call raised TypeError.
    """
    # NOTE(review): ``config``, ``seed_everything`` and ``evaluate`` are not
    # defined in this part of the file; presumably provided at module level.
    seed_everything(config.seed)
    train_dataset, val_dataset, test_dataset = get_tfds_from_dir(config)
    # Use the class names inferred from the training sub-directories (the same
    # order as the one-hot label columns) instead of hard-coded placeholders.
    labels = train_dataset.class_names
    model, history = train(config, train_dataset, val_dataset, labels)
    evaluate(config, model, history, test_dataset, labels)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment