-
-
Save helderc/112562d6090cbd535d0c26bbfd952f77 to your computer and use it in GitHub Desktop.
How to do data augmentation on a keras HDF5Matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Another way, note this one will load the whole array into memory .""" | |
from keras.preprocessing.image import ImageDataGenerator | |
import h5py | |
from keras.utils.io_utils import HDF5Matrix | |
seed=0 | |
batch_size=32 | |
# we create two instances with the same arguments | |
data_gen_args = dict( | |
rotation_range=90., | |
width_shift_range=0.05, | |
height_shift_range=0.05, | |
zoom_range=0.2, | |
channel_shift_range=0.005, | |
horizontal_flip=True, | |
vertical_flip=True, | |
fill_mode='constant', | |
data_format="channels_last", | |
) | |
image_datagen = ImageDataGenerator(**data_gen_args) | |
mask_datagen = ImageDataGenerator(**data_gen_args) | |
X_train = HDF5Matrix(os.path.join(out_dir, 'train_X_3band.h5'), 'X') | |
y_train = HDF5Matrix(os.path.join(out_dir, 'train_y_3class.h5'), 'y') | |
image_generator = image_datagen.flow( | |
X_train, None, | |
seed=seed, | |
batch_size=batch_size, | |
) | |
mask_generator = mask_datagen.flow( | |
y_train, None, | |
seed=seed, | |
batch_size=batch_size, | |
) | |
# combine generators into one which yields image and masks | |
train_generator = zip(image_generator, mask_generator) | |
train_generator | |
X, y = next(train_generator) | |
X.shape, y.shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""How to do data augmentation on a keras HDF5Matrix""" | |
import numpy as np
from keras.utils.io_utils import HDF5Matrix
class AugumentedHDF5Matrix(HDF5Matrix):
    """Wrap an HDF5Matrix so that every read is passed through image augmentation.

    Every image returned by ``__getitem__`` is transformed with
    ``image_datagen.random_transform`` using a seed built from ``seed`` plus
    an internal counter. The counter is bumped once per augmented image.
    Consequences:

    * two instances created with the same ``seed`` yield the same sequence of
      augmented reads (repeatable);
    * repeated reads from one instance differ from each other;
    * images within a single batch read get different transforms.

    NOTE(review): in Keras, passing ``seed`` to ``random_transform`` reseeds
    NumPy's global RNG. Confirm this side effect is acceptable for your
    pipeline.

    Args:
        image_datagen: object providing ``random_transform(x, seed=...)``,
            e.g. a Keras ``ImageDataGenerator``.
        seed: base integer seed for the transforms.
        *args, **kwargs: forwarded unchanged to ``HDF5Matrix``
            (e.g. ``datapath, dataset``).
    """

    def __init__(self, image_datagen, seed, *args, **kwargs):
        self.image_datagen = image_datagen
        self.seed = seed
        # Number of images augmented so far; mixed into the per-image seed.
        self.i = 0
        super().__init__(*args, **kwargs)

    def _augment(self, image):
        """Return one randomly transformed copy of a single image."""
        self.i += 1
        return self.image_datagen.random_transform(
            image, seed=self.seed + self.i)

    def __getitem__(self, key):
        """Read ``key`` from the HDF5 dataset and return it augmented.

        A 3-D result is treated as one image. Anything else is treated as a
        batch, and each image along the first axis gets its own seed, so a
        batch is no longer transformed identically.
        """
        x = super().__getitem__(key)
        # Assumes single samples are 3-D (e.g. H, W, C). TODO: confirm for
        # other data layouts.
        if len(x.shape) == 3:
            return self._augment(x)
        augmented = [self._augment(xx) for xx in x]
        # An empty selection is returned as-is so its shape is preserved
        # (np.array([]) would collapse it to shape (0,)).
        return np.array(augmented) if augmented else x
# Test
import os
import tempfile

import h5py
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot as plt

# A Keras ImageDataGenerator, used here only through random_transform().
# NOTE(review): ``rescale`` is, as far as I know, applied by ``standardize()``
# rather than ``random_transform()``, so it likely has no effect in this
# test. Confirm for your Keras version.
image_datagen = ImageDataGenerator(
    width_shift_range=0.05,
    height_shift_range=0.05,
    zoom_range=0.1,
    channel_shift_range=0.005,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='constant',
    data_format="channels_last",
    rescale=1 / 255.0)

# Test HDF5 file: random noise plus two constant patches, so the augmentation
# is visible when the images are plotted.
images = np.random.random((100, 244, 244, 3))
images[:, 20:30, 20:50, :] = 1
images[:, 50:70, 20:30, :] = 0
# Use the platform's temp directory instead of a hard-coded /tmp path.
datapath = os.path.join(tempfile.gettempdir(), "testfile5.hdf5")
with h5py.File(datapath, "w") as f:
    f.create_dataset("X", data=images)

# A fresh instance with the same seed must reproduce the first read; a second
# read from the same instance advances the seed and should differ.
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X')
a = X[0].mean()
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X')
b = X[0].mean()
assert a == b, 'should be repeatable'
c = X[0].mean()
assert b != c, 'and random'

# Slicing and fancy indexing must work and yield images of the stored shape.
assert X[1:2][0].shape == images.shape[1:]
assert X[[1, 2]][0].shape == images.shape[1:]

# View a few augmentations of the same image.
for _ in range(5):
    plt.imshow(X[0])
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment