Last active
December 24, 2019 08:02
-
-
Save devxpy/a73744bab1b77a79bcad553cbe589493 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage examples for PersonDataGenerator. The two snippets below are
# alternatives: each one rebinds ``train_gen``, so keep only the one you need.

# Example 1: generates 4 sets of images
#   - original images
#   - images with rotation
#   - images with horizontal flip
#   - images with vertical flip
train_gen = PersonDataGenerator(
    train_df,
    batch_size=32,
    aug_list=[
        ImageDataGenerator(rotation_range=45),
        ImageDataGenerator(horizontal_flip=True),
        ImageDataGenerator(vertical_flip=True),
    ],
    incl_orig=True,  # Whether to include original images
)

# Example 2: generates 1 set of augmented images, not including the
# original images
train_gen = PersonDataGenerator(
    train_df,
    batch_size=32,
    aug_list=[
        ImageDataGenerator(rotation_range=45),
    ],
    incl_orig=False,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class PersonDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding image batches with multi-output targets.

    The sequence is organised as consecutive passes over the dataframe, each
    ``orig_len`` batches long: first one pass of un-augmented images (only
    when ``incl_orig`` is true), then one pass per entry of ``aug_list``.

    Args:
        df: DataFrame with an ``image_path`` column plus the label columns
            selected by the ``_*_cols_`` names (defined outside this class).
        batch_size: Number of dataframe rows per batch. Rows that do not
            fill a complete batch are dropped.
        shuffle: Whether to shuffle the rows on construction and after
            every epoch.
        aug_list: Augmenters exposing ``flow(images, ...)`` (e.g. Keras
            ``ImageDataGenerator``); each one contributes a full pass.
            ``None`` (the default) means no augmenters.
        incl_orig: Whether to include a pass of the original images.
    """

    def __init__(self, df, batch_size=32, shuffle=True, aug_list=None, incl_orig=True):
        super().__init__()
        self.df = df
        self.batch_size = batch_size
        self.shuffle = shuffle
        # ``None`` instead of a ``[]`` default avoids sharing one list
        # object between every instance built without ``aug_list``.
        self.aug_list = [] if aug_list is None else aug_list
        self.incl_orig = incl_orig
        # Number of full batches in a single pass over the dataframe.
        self.orig_len = self.df.shape[0] // self.batch_size
        # Apply the initial shuffle (no-op when ``shuffle`` is false).
        self.on_epoch_end()

    def __len__(self):
        """Return the total number of batches across all passes."""
        passes = len(self.aug_list) + (1 if self.incl_orig else 0)
        return self.orig_len * passes

    def __getitem__(self, index):
        """Return ``(images, target)`` for batch number ``index``.

        Raises:
            IndexError: If ``index`` is outside ``range(len(self))``.
            FileNotFoundError: If an image in the batch cannot be read.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"batch index {index} out of range for {len(self)} batches"
            )

        # Which pass the batch belongs to, and its position inside that pass.
        pass_idx, batch_idx = divmod(index, self.orig_len)
        if self.incl_orig:
            # Pass 0 is the un-augmented data; augmenters start at pass 1.
            aug = None if pass_idx == 0 else self.aug_list[pass_idx - 1]
        else:
            # No original pass: pass k maps straight to augmenter k, so that
            # batch 0 is already augmented rather than an original batch.
            aug = self.aug_list[pass_idx]

        start = batch_idx * self.batch_size
        items = self.df.iloc[start:start + self.batch_size]

        loaded = []
        for path in items["image_path"]:
            img = cv2.imread(path)
            # cv2.imread signals an unreadable file by returning None
            # instead of raising; fail loudly rather than stack garbage.
            if img is None:
                raise FileNotFoundError(f"Could not read image: {path!r}")
            loaded.append(img)
        images = np.stack(loaded)

        if aug is not None:
            # Request the whole batch explicitly: flow() otherwise uses its
            # own default batch size and could return fewer images than
            # there are target rows.
            images = aug.flow(images, batch_size=len(images), shuffle=False).next()

        target = {
            "gender_output": items[_gender_cols_].values,
            "image_quality_output": items[_imagequality_cols_].values,
            "age_output": items[_age_cols_].values,
            "weight_output": items[_weight_cols_].values,
            "bag_output": items[_carryingbag_cols_].values,
            "pose_output": items[_bodypose_cols_].values,
            "footwear_output": items[_footwear_cols_].values,
            "emotion_output": items[_emotion_cols_].values,
        }
        return images, target

    def on_epoch_end(self):
        """Reshuffle the dataframe rows when shuffling is enabled."""
        if self.shuffle:
            self.df = self.df.sample(frac=1).reset_index(drop=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment