Skip to content

Instantly share code, notes, and snippets.

@karolzak
Last active January 16, 2019 22:27
Show Gist options
  • Save karolzak/f8bf0eae07a83939c73cdf0afc43429c to your computer and use it in GitHub Desktop.
Save karolzak/f8bf0eae07a83939c73cdf0afc43429c to your computer and use it in GitHub Desktop.
image augmentation for semantic segmentation models
from keras.preprocessing.image import ImageDataGenerator
# Runtime data augmentation
def get_augmented(
        X_train,
        Y_train,
        X_val,
        Y_val,
        batch_size=32,
        seed=0,
        data_gen_args=None):
    """Build paired image/mask generators for training and validation.

    Images and masks get separate ``ImageDataGenerator`` instances built from
    the same keyword arguments and driven by the same ``seed``, so the random
    shuffling and transforms applied to an image batch are also applied to
    the corresponding mask batch.

    Args:
        X_train: Training images. Presumably a rank-4 NumPy array, as
            ``ImageDataGenerator.flow`` expects.
        Y_train: Training masks, aligned sample-for-sample with ``X_train``.
        X_val: Validation images.
        Y_val: Validation masks, aligned sample-for-sample with ``X_val``.
        batch_size: Number of samples per yielded batch.
        seed: Seed shared by the image and mask iterators. Using the same
            value for both is what keeps images and masks in sync.
        data_gen_args: Keyword arguments for the training
            ``ImageDataGenerator``s. ``None`` (the default) means horizontal
            and vertical flips with ``fill_mode='constant'``. Further options
            such as ``rotation_range``, ``width_shift_range``,
            ``height_shift_range``, ``shear_range`` or ``zoom_range`` can be
            supplied here.

    Returns:
        ``(train_generator, val_generator)``: two ``zip`` iterators, each
        yielding ``(image_batch, mask_batch)`` tuples. Validation batches are
        not augmented.

    Raises:
        ValueError: If the images and masks of a split differ in length.
    """
    if data_gen_args is None:
        # Built per call instead of as a default argument value, so the dict
        # is never shared (or accidentally mutated) across calls.
        data_gen_args = dict(
            horizontal_flip=True,
            vertical_flip=True,
            fill_mode='constant',
        )

    # zip() pairs image and mask batches positionally. With different sample
    # counts the two iterators would shuffle differently and silently yield
    # masks that do not belong to their images.
    for images, masks, split in ((X_train, Y_train, 'training'),
                                 (X_val, Y_val, 'validation')):
        if len(images) != len(masks):
            raise ValueError(
                '%s images and masks differ in length: %d vs %d'
                % (split, len(images), len(masks)))

    # Options whose statistics must be computed by ImageDataGenerator.fit().
    featurewise_options = (
        'featurewise_center',
        'featurewise_std_normalization',
        'zca_whitening',
    )

    # Train data: provide the same seed and keyword arguments to fit and flow.
    X_datagen = ImageDataGenerator(**data_gen_args)
    Y_datagen = ImageDataGenerator(**data_gen_args)
    # fit() only matters for the featurewise/ZCA options. With augment=True it
    # also builds an augmented float copy of the whole array, so skip it
    # when none of those options is enabled.
    if any(data_gen_args.get(option) for option in featurewise_options):
        # NOTE(review): this also fits (and later normalizes) the masks, and
        # the validation generators below do not apply these statistics.
        # Confirm this is intended before enabling featurewise options.
        X_datagen.fit(X_train, augment=True, seed=seed)
        Y_datagen.fit(Y_train, augment=True, seed=seed)
    X_train_augmented = X_datagen.flow(
        X_train, batch_size=batch_size, shuffle=True, seed=seed)
    Y_train_augmented = Y_datagen.flow(
        Y_train, batch_size=batch_size, shuffle=True, seed=seed)

    # Validation data: no augmentation, but still wrapped in generators so
    # both splits are consumed the same way. No fit() is needed, because a
    # default ImageDataGenerator has no data-dependent statistics.
    X_datagen_val = ImageDataGenerator()
    Y_datagen_val = ImageDataGenerator()
    X_val_augmented = X_datagen_val.flow(
        X_val, batch_size=batch_size, shuffle=True, seed=seed)
    Y_val_augmented = Y_datagen_val.flow(
        Y_val, batch_size=batch_size, shuffle=True, seed=seed)

    # Combine the generators into ones that yield (images, masks) tuples.
    train_generator = zip(X_train_augmented, Y_train_augmented)
    val_generator = zip(X_val_augmented, Y_val_augmented)
    return train_generator, val_generator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment