Example of incorporating RandAugment in a tf.data pipeline for image classification.
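Before the full pipeline, here is a quick standalone illustration of the augmenter interface the gist relies on. This is a minimal sketch on a dummy random batch; the array shapes are illustrative and not part of the original gist.

import numpy as np
from imgaug import augmenters as iaa

# RandAugment applies n randomly chosen operations per image at magnitude m
aug = iaa.RandAugment(n=2, m=9)

# Dummy batch of four 224x224 RGB images (illustrative only)
images = np.random.randint(0, 256, size=(4, 224, 224, 3), dtype=np.uint8)
augmented = aug(images=images)
print(augmented.shape)  # (4, 224, 224, 3)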
from imgaug import augmenters as iaa
import imgaug as ia
import tensorflow as tf

# Fix the random seeds for reproducibility
ia.seed(4)
tf.random.set_seed(666)

# RandAugment: two randomly chosen operations per image, at magnitude 9
aug = iaa.RandAugment(n=2, m=9)

AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 224

def augment(images):
    # imgaug operates on NumPy arrays, not TensorFlow tensors,
    # so convert the batch before augmenting
    return aug(images=images.numpy())

# Function to read the TFRecords and segregate the images and labels
def read_tfrecord(example):
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, features)
    image = tf.image.decode_jpeg(example["image"], channels=3)
    class_label = tf.cast(example["class"], tf.int32)
    return (image, class_label)

# Load the TFRecords and create a tf.data.Dataset
def load_dataset(filenames):
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)
    return dataset

# Repeat, shuffle, and batch the dataset, and prefetch it
# well before the current epoch ends
def batch_dataset(filenames, batch_size=BATCH_SIZE, train=True):
    opt = tf.data.Options()
    opt.experimental_deterministic = False
    dataset = load_dataset(filenames)
    if train:
        dataset = dataset.repeat()
        dataset = dataset.shuffle(batch_size * 10)
        dataset = dataset.batch(batch_size)
        # imgaug runs in plain Python, so wrap it in tf.py_function
        dataset = dataset.map(
            lambda x, y: (tf.py_function(augment, [x], [tf.float32]), y),
            num_parallel_calls=AUTO,
        )
        # tf.py_function returns a list of tensors; squeeze out
        # the extra leading dimension it introduces
        dataset = dataset.map(
            lambda x, y: (tf.squeeze(x), y),
            num_parallel_calls=AUTO,
        )
    else:
        dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTO)
    dataset = dataset.with_options(opt)
    return dataset

train_pattern = "train_tfr_224/*.tfrec"
train_filenames = tf.io.gfile.glob(train_pattern)
val_pattern = "val_tfr_224/*.tfrec"
val_filenames = tf.io.gfile.glob(val_pattern)

training_ds = batch_dataset(train_filenames)
validation_ds = batch_dataset(val_filenames, train=False)
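For context, these datasets plug directly into Keras training. Below is a minimal sketch; the model choice and NUM_TRAIN_IMAGES are hypothetical stand-ins, and because the training dataset repeats indefinitely, steps_per_epoch must be supplied.

NUM_TRAIN_IMAGES = 20000  # assumption: replace with your actual training-set size
NUM_CLASSES = 5           # assumption: replace with your actual number of classes

model = tf.keras.applications.ResNet50(weights=None, classes=NUM_CLASSES)
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# training_ds repeats forever, so tell Keras how many batches make one epoch
model.fit(training_ds,
          validation_data=validation_ds,
          steps_per_epoch=NUM_TRAIN_IMAGES // BATCH_SIZE,
          epochs=5)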
Thanks to @DarshanDeshpande for catching an off-by-one bug.