Last active
January 16, 2019 22:27
-
-
Save karolzak/f8bf0eae07a83939c73cdf0afc43429c to your computer and use it in GitHub Desktop.
image augmentation for semantic segmentation models
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from keras.preprocessing.image import ImageDataGenerator | |
| # Runtime data augmentation | |
def get_augmented(
    X_train,
    Y_train,
    X_val,
    Y_val,
    batch_size=32,
    seed=0,
    data_gen_args=None,
):
    """Build paired image/mask batch generators for training and validation.

    Two ``ImageDataGenerator`` instances are created for the training data,
    one for the images and one for the masks. Both receive the same keyword
    arguments and the same ``seed`` so that shuffling and random transforms
    are meant to stay aligned between an image and its mask.

    Args:
        X_train: Training images passed to ``ImageDataGenerator.flow``.
        Y_train: Training masks, in the same order as ``X_train``.
        X_val: Validation images.
        Y_val: Validation masks, in the same order as ``X_val``.
        batch_size: Batch size handed to every ``flow`` call.
        seed: Seed handed to every ``fit`` and ``flow`` call.
        data_gen_args: Keyword arguments for the two training
            ``ImageDataGenerator`` instances. When ``None`` (the default),
            horizontal and vertical flips with ``fill_mode='constant'`` are
            used. Other options such as ``rotation_range``,
            ``width_shift_range``, ``height_shift_range``, ``shear_range``
            or ``zoom_range`` can be supplied here.

    Returns:
        A ``(train_generator, val_generator)`` tuple. Each element is a
        ``zip`` of the image iterator with the mask iterator, so every item
        it produces is an ``(image_batch, mask_batch)`` pair.
    """
    # Build the default per call instead of sharing one dict object across
    # all invocations (mutable default argument pitfall).
    if data_gen_args is None:
        data_gen_args = dict(
            horizontal_flip=True,
            vertical_flip=True,
            fill_mode='constant',
        )

    # Train data: provide the same seed and keyword arguments to the fit and
    # flow methods of the image and the mask generator.
    X_datagen = ImageDataGenerator(**data_gen_args)
    Y_datagen = ImageDataGenerator(**data_gen_args)
    X_datagen.fit(X_train, augment=True, seed=seed)
    Y_datagen.fit(Y_train, augment=True, seed=seed)
    X_train_augmented = X_datagen.flow(
        X_train, batch_size=batch_size, shuffle=True, seed=seed)
    Y_train_augmented = Y_datagen.flow(
        Y_train, batch_size=batch_size, shuffle=True, seed=seed)

    # Validation data: no augmentation, but generators are still created so
    # that both splits are consumed the same way. The generators are left at
    # their defaults, so there are no feature-wise statistics to compute and
    # `fit` is not needed here (per the Keras docs it is only required for
    # featurewise_center / featurewise_std_normalization / zca_whitening).
    X_datagen_val = ImageDataGenerator()
    Y_datagen_val = ImageDataGenerator()
    X_val_augmented = X_datagen_val.flow(
        X_val, batch_size=batch_size, shuffle=True, seed=seed)
    Y_val_augmented = Y_datagen_val.flow(
        Y_val, batch_size=batch_size, shuffle=True, seed=seed)

    # Combine the generators into ones which yield (images, masks) pairs.
    train_generator = zip(X_train_augmented, Y_train_augmented)
    val_generator = zip(X_val_augmented, Y_val_augmented)

    return train_generator, val_generator
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment