Skip to content

Instantly share code, notes, and snippets.

@devxpy
Last active December 24, 2019 08:02
Show Gist options
  • Save devxpy/a73744bab1b77a79bcad553cbe589493 to your computer and use it in GitHub Desktop.
Save devxpy/a73744bab1b77a79bcad553cbe589493 to your computer and use it in GitHub Desktop.
# Example 1 -- four "views" of the data are generated per epoch:
#   - the original, un-augmented images (incl_orig=True)
#   - the images with random rotations within +/-45 degrees
#   - the images with random horizontal flips
#   - the images with random vertical flips
# NOTE(review): in this file these example calls appear before the
# PersonDataGenerator class statement; run top-to-bottom as one module they
# would raise NameError unless the class, train_df and ImageDataGenerator are
# already in scope -- confirm the intended layout.
train_gen = PersonDataGenerator(
    train_df,
    aug_list=[
        ImageDataGenerator(rotation_range=45),
        ImageDataGenerator(horizontal_flip=True),
        ImageDataGenerator(vertical_flip=True),
    ],
    batch_size=32,
    incl_orig=True,  # also yield the un-augmented batches
)

# Example 2 -- a single augmented view (rotation only); the original images
# are NOT yielded. Note this rebinds train_gen, replacing Example 1.
train_gen = PersonDataGenerator(
    train_df,
    aug_list=[ImageDataGenerator(rotation_range=45)],
    batch_size=32,
    incl_orig=False,
)
class PersonDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding image batches plus a dict of targets.

    Every full batch of ``df`` is served once per "view" in each epoch. The
    views are, in order: the un-augmented images (only when ``incl_orig`` is
    true), then one view per augmenter in ``aug_list``. Trailing rows that do
    not fill a complete batch (``len(df) % batch_size``) are skipped.

    Args:
        df: DataFrame with an ``"image_path"`` column plus the target columns
            named by the module-level ``_*_cols_`` lists (defined elsewhere).
        batch_size: Rows per batch; must be at least 1.
        shuffle: When true, rows are reshuffled at construction time and
            whenever ``on_epoch_end`` is called.
        aug_list: Augmenters exposing ``flow`` (e.g. Keras
            ``ImageDataGenerator``); each adds one full augmented pass.
            ``None`` (the default) means no augmentation.
        incl_orig: Whether to also yield the un-augmented batches.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """

    def __init__(self, df, batch_size=32, shuffle=True, aug_list=None, incl_orig=True):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.df = df
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Copy so instances never share (or alias) a caller's list; this also
        # replaces the former mutable ``[]`` default argument.
        self.aug_list = list(aug_list) if aug_list is not None else []
        self.incl_orig = incl_orig
        # Number of *full* batches in df; leftover rows are not served.
        self.orig_len = self.df.shape[0] // self.batch_size
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (full batches x views)."""
        n_views = len(self.aug_list) + (1 if self.incl_orig else 0)
        return self.orig_len * n_views

    def __getitem__(self, index):
        """Return ``(images, target)`` for the batch at flat position ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
            OSError: if ``cv2.imread`` cannot read one of the image files.
        """
        aug, batch_index = self._resolve_index(index)
        start = batch_index * self.batch_size
        items = self.df.iloc[start:start + self.batch_size]
        images = self._load_images(items)
        if aug is not None:
            # flow() defaults to batch_size=32; pass the real batch length so
            # larger batches are not truncated to 32 images (which would no
            # longer line up with the target rows below).
            images = aug.flow(images, batch_size=len(images), shuffle=False).next()
            # NOTE(review): flow() presumably returns float32 while the
            # un-augmented view keeps cv2's dtype -- confirm this is intended.
        # Keys are presumably the model's output layer names -- confirm.
        target = {
            "gender_output": items[_gender_cols_].values,
            "image_quality_output": items[_imagequality_cols_].values,
            "age_output": items[_age_cols_].values,
            "weight_output": items[_weight_cols_].values,
            "bag_output": items[_carryingbag_cols_].values,
            "pose_output": items[_bodypose_cols_].values,
            "footwear_output": items[_footwear_cols_].values,
            "emotion_output": items[_emotion_cols_].values,
        }
        return images, target

    def _resolve_index(self, index):
        """Map a flat index to ``(augmenter or None, batch index within df)``."""
        if not 0 <= index < len(self):
            raise IndexError(
                f"batch index {index} out of range for {len(self)} batches"
            )
        view, batch_index = divmod(index, self.orig_len)
        if self.incl_orig:
            # View 0 is the un-augmented pass; shift so augmenters start at 0.
            view -= 1
        aug = self.aug_list[view] if view >= 0 else None
        return aug, batch_index

    @staticmethod
    def _load_images(items):
        """Read each ``image_path`` in ``items`` with cv2 and stack the arrays.

        All images must share one shape for ``np.stack`` to succeed.
        """
        images = []
        for path in items["image_path"]:
            image = cv2.imread(path)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                raise OSError(f"cv2.imread could not read image: {path!r}")
            images.append(image)
        return np.stack(images)

    def on_epoch_end(self):
        """Reshuffle the rows of ``df`` when ``shuffle`` is enabled.

        Keras calls this after every epoch; ``__init__`` also calls it so the
        first epoch is shuffled as well.
        """
        if self.shuffle:
            self.df = self.df.sample(frac=1).reset_index(drop=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment