Created
January 17, 2019 15:31
-
-
Save tmwatchanan/bd8724730fc6aa033d47dc7bc3df65c2 to your computer and use it in GitHub Desktop.
Keras: creating a generator for the training set
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def trainGenerator(batch_size,
                   train_path,
                   image_folder,
                   mask_folder,
                   aug_dict,
                   image_color_mode="grayscale",
                   mask_color_mode="grayscale",
                   image_save_prefix="image",
                   mask_save_prefix="mask",
                   flag_multi_class=False,
                   num_class=2,
                   save_to_dir=None,
                   target_size=(256, 256),
                   seed=1):
    """Yield paired, identically augmented (image, mask) training batches.

    Two ``ImageDataGenerator`` objects are built from the same ``aug_dict``.
    Each one reads a single sub-folder of ``train_path`` through
    ``flow_from_directory``. Both flows receive the same ``seed``, so their
    shuffling and random transformations are drawn identically. That keeps
    each image aligned with its mask.

    NOTE(review): alignment also assumes ``image_folder`` and
    ``mask_folder`` list their files in corresponding order with the same
    count. Confirm this against the dataset layout.

    Args:
        batch_size: Number of samples per batch, used by both flows.
        train_path: Root directory that contains ``image_folder`` and
            ``mask_folder``.
        image_folder: Sub-directory of ``train_path`` holding the images.
            It is passed as the only entry of ``classes``.
        mask_folder: Sub-directory of ``train_path`` holding the masks.
        aug_dict: Keyword arguments for ``ImageDataGenerator``. The same
            dict is applied to both images and masks.
        image_color_mode: ``color_mode`` for the image flow.
        mask_color_mode: ``color_mode`` for the mask flow.
        image_save_prefix: Filename prefix for saved augmented images
            (only used when ``save_to_dir`` is set).
        mask_save_prefix: Filename prefix for saved augmented masks
            (only used when ``save_to_dir`` is set).
        flag_multi_class: Forwarded unchanged to ``adjustData``.
        num_class: Forwarded unchanged to ``adjustData``.
        save_to_dir: If not None, both flows also write their augmented
            samples to this directory. This is useful for visually checking
            the augmentation.
        target_size: Size every image and mask is resized to when loaded.
        seed: Random seed shared by both flows.

    Yields:
        tuple: ``(img, mask)``, the pair returned by ``adjustData`` for each
        batch drawn from the two flows.
    """
    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)

    # Settings that must match between the two flows so image/mask batches
    # stay aligned. Defining them once prevents the two calls drifting apart.
    shared_flow_kwargs = dict(
        class_mode=None,  # yield only the arrays, no class labels
        target_size=target_size,
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        seed=seed,
    )

    image_generator = image_datagen.flow_from_directory(
        train_path,
        classes=[image_folder],
        color_mode=image_color_mode,
        save_prefix=image_save_prefix,
        **shared_flow_kwargs)
    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes=[mask_folder],
        color_mode=mask_color_mode,
        save_prefix=mask_save_prefix,
        **shared_flow_kwargs)

    # Draw one batch from each flow in lockstep and post-process the pair.
    for img, mask in zip(image_generator, mask_generator):
        img, mask = adjustData(img, mask, flag_multi_class, num_class)
        yield (img, mask)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
An example of a data augmentation dictionary