[KERAS] CustomImageDataGenerator images and masks
import os

import numpy as np
from keras.utils import Sequence
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

ia.seed(1)

train_imgs = np.load('train_imgs.npy')
train_masks = np.load('train_masks.npy')
val_imgs = np.load('val_imgs.npy')
val_masks = np.load('val_masks.npy')
class DataGeneratorImageMask(Sequence):
    """Keras Sequence that yields batches of (image, mask) pairs with optional imgaug augmentation."""

    def __init__(self, images, masks, batch_size=32, input_size=(256, 256, 3),
                 num_classes=None, shuffle=True, transform=None):
        self.batch_size = batch_size
        self.images = images
        self.masks = masks
        self.input_size = input_size
        self.transform = transform  # an imgaug augmenter, applied to image and mask together
        self.indices = list(range(self.images.shape[0]))
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.on_epoch_end()
    def __len__(self):
        # number of full batches per epoch
        return len(self.indices) // self.batch_size

    def __getitem__(self, index):
        # slice of the (possibly shuffled) index order for this batch
        batch_index = self.index[index * self.batch_size:(index + 1) * self.batch_size]
        batch = [self.indices[k] for k in batch_index]
        X, y = self.__get_data(batch)
        return X, y

    def on_epoch_end(self):
        self.index = np.arange(len(self.indices))
        if self.shuffle:
            np.random.shuffle(self.index)
    def __get_data(self, batch):
        X = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], self.input_size[2]))
        y = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], 1))
        for i, idx in enumerate(batch):
            img = self.images[idx]
            mask = self.masks[idx]  # masks are expected to carry a trailing channel dim, e.g. (H, W, 1)
            if self.transform:
                segmap = SegmentationMapsOnImage(mask, shape=img.shape)
                img, segmap_aug = self.transform(image=img, segmentation_maps=segmap)  # ver 1
                # # ver 2
                # aug_det = self.transform.to_deterministic()
                # img = aug_det.augment_image(img)
                # segmap_aug = aug_det.augment_segmentation_maps(segmap)
                mask = segmap_aug.get_arr()
            X[i] = img
            y[i] = mask
        X = X.astype(np.float32) / 255.0  # scale images to [0, 1]
        y = y.astype(np.float32)
        return X, y
seq = iaa.SomeOf((0, None), [
    # iaa.CropAndPad(percent=(-0.25, 0.25)),
    # iaa.HorizontalFlip(1),  # horizontally flip the images
    # iaa.VerticalFlip(1),  # vertically flip the images
    iaa.Rot90([1]),  # 90
    # iaa.Rot90([2]),  # 180
    # iaa.Rot90([3]),  # 270
    # iaa.Clouds(),
    # iaa.Fog(),
    # iaa.Crop(px=(5, 16)),  # crop images from each side by 5 to 16px (randomly chosen)
    iaa.Affine(
        scale={"x": (1, 1.4), "y": (1, 1.4)},
        # translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
        rotate=(-5, 5),
        shear=(-10, 10)
    ),
    iaa.AddToHue((-20, 20)),
    iaa.AdditiveGaussianNoise(scale=0.02 * 255),
    iaa.GammaContrast((0.8, 1.6)),
])
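
# Optional sanity check (not part of the original gist): apply the pipeline once to a
# single image/mask pair and confirm both are transformed together and stay aligned.
# Assumes masks use an integer dtype, as required by SegmentationMapsOnImage.
sample_img = train_imgs[0]
sample_segmap = SegmentationMapsOnImage(train_masks[0].astype(np.int32), shape=sample_img.shape)
aug_img, aug_segmap = seq(image=sample_img, segmentation_maps=sample_segmap)
print(aug_img.shape, aug_segmap.get_arr().shape)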
train_batch_size = 64
val_batch_size = 64

train_generator = DataGeneratorImageMask(train_imgs, train_masks, batch_size=train_batch_size, shuffle=True, transform=seq)
val_generator = DataGeneratorImageMask(val_imgs, val_masks, batch_size=val_batch_size, shuffle=False, transform=None)
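
# Optional sanity check (not part of the original gist): pull one batch from the
# training generator and verify the shapes and value range the model will see.
X_batch, y_batch = train_generator[0]
print(X_batch.shape, y_batch.shape)   # expected: (64, 256, 256, 3) and (64, 256, 256, 1)
print(X_batch.min(), X_batch.max())   # images are scaled to [0, 1] inside __get_data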
model = some_model()

path = 'snapshots/'
if not os.path.exists(path):
    os.makedirs(path)

earlystopper = EarlyStopping(patience=15, verbose=1)  # defined here but not added to the callbacks list below
# filename template assumes the model logs metrics named 'iou' and 'val_metric'
filepath = path + 'e{epoch:02d}_b64_val{val_loss:.4f}_iou{iou:.4f}_iouv{val_metric:.4f}.h5'

callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10, min_lr=1e-9, min_delta=0.00001, verbose=1, mode='min'),
    # EarlyStopping(monitor='val_loss', patience=patience, verbose=0),
    ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True, verbose=0),
]
# note: in newer Keras versions, model.fit(...) accepts a Sequence directly
results = model.fit_generator(train_generator,
                              epochs=100,
                              steps_per_epoch=train_imgs.shape[0] // train_batch_size,
                              validation_steps=val_imgs.shape[0] // val_batch_size,
                              validation_data=val_generator,
                              verbose=1, callbacks=callbacks
                              )
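
# Hedged sketch (not part of the original gist): predict on one validation batch and
# threshold the output, assuming a single-channel sigmoid segmentation head.
X_val, y_val = val_generator[0]
pred = model.predict(X_val)
pred_binary = (pred > 0.5).astype(np.uint8)
print(pred.shape, pred_binary.sum())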