Skip to content

Instantly share code, notes, and snippets.

@tmwatchanan
Created January 17, 2019 15:31
Show Gist options
  • Select an option

  • Save tmwatchanan/bd8724730fc6aa033d47dc7bc3df65c2 to your computer and use it in GitHub Desktop.

Select an option

Save tmwatchanan/bd8724730fc6aa033d47dc7bc3df65c2 to your computer and use it in GitHub Desktop.
Keras: creating a generator for the training set
def trainGenerator(batch_size,
                   train_path,
                   image_folder,
                   mask_folder,
                   aug_dict,
                   image_color_mode="grayscale",
                   mask_color_mode="grayscale",
                   image_save_prefix="image",
                   mask_save_prefix="mask",
                   flag_multi_class=False,
                   num_class=2,
                   save_to_dir=None,
                   target_size=(256, 256),
                   seed=1):
    """Yield paired (image, mask) batches for segmentation training.

    Two ``ImageDataGenerator`` objects are built from the same ``aug_dict``.
    Their ``flow_from_directory`` iterators are created with the same
    ``seed``. The intent is for each image and its mask to receive the same
    random transformation, so the pair stays aligned.

    Args:
        batch_size: Number of samples per batch for both iterators.
        train_path: Root directory passed to ``flow_from_directory``.
        image_folder: Sub-directory of ``train_path`` holding the images.
            It is passed as the only entry of ``classes``.
        mask_folder: Sub-directory of ``train_path`` holding the masks.
        aug_dict: Keyword arguments forwarded to ``ImageDataGenerator``.
            Both generators receive the same arguments.
        image_color_mode: ``color_mode`` used when loading images.
        mask_color_mode: ``color_mode`` used when loading masks.
        image_save_prefix: Filename prefix for saved image batches.
        mask_save_prefix: Filename prefix for saved mask batches.
        flag_multi_class: Forwarded unchanged to ``adjustData``.
        num_class: Forwarded unchanged to ``adjustData``.
        save_to_dir: If set, augmented batches are also written to this
            directory. This is useful for visually checking the augmentation.
        target_size: Size that every loaded image and mask is resized to.
        seed: Random seed shared by both iterators to keep them in sync.

    Yields:
        ``(img, mask)`` tuples, as returned by ``adjustData``.
        NOTE(review): Keras directory iterators are expected to cycle
        indefinitely, so this generator presumably never stops on its own.
        Bound it with ``steps_per_epoch`` or a similar limit.
    """
    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)

    # class_mode=None: the iterators yield only the loaded arrays, no labels.
    image_generator = image_datagen.flow_from_directory(
        train_path,
        classes=[image_folder],
        class_mode=None,
        color_mode=image_color_mode,
        target_size=target_size,
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        save_prefix=image_save_prefix,
        seed=seed)
    mask_generator = mask_datagen.flow_from_directory(
        train_path,
        classes=[mask_folder],
        class_mode=None,
        color_mode=mask_color_mode,
        target_size=target_size,
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        save_prefix=mask_save_prefix,
        seed=seed)

    # Advance both iterators in lockstep. The shared seed is what keeps
    # batch i of images paired with batch i of masks.
    for img, mask in zip(image_generator, mask_generator):
        img, mask = adjustData(img, mask, flag_multi_class, num_class)
        yield (img, mask)
@tmwatchanan
Copy link
Copy Markdown
Author

An example data augmentation dictionary:

# Example augmentation settings, meant to be passed as ``aug_dict`` so the
# image and mask generators share identical augmentation parameters.
# NOTE(review): Keras interprets rotation_range in degrees, so 0.2 permits
# only about ±0.2° of rotation — confirm this small value is intended.
data_gen_args = {
    "rotation_range": 0.2,
    "width_shift_range": 0.05,
    "height_shift_range": 0.05,
    "shear_range": 0.05,
    "zoom_range": 0.05,
    "horizontal_flip": True,
    "fill_mode": "nearest",
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment