Last active
October 11, 2020 23:32
-
-
Save wassname/74f02bc9134897e3fe4e60784f5aaa15 to your computer and use it in GitHub Desktop.
How to do data augmentation on a keras HDF5Matrix
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Another way, note this one will load the whole array into memory .""" | |
from keras.preprocessing.image import ImageDataGenerator | |
import h5py | |
from keras.utils.io_utils import HDF5Matrix | |
seed=0 | |
batch_size=32 | |
# we create two instances with the same arguments | |
data_gen_args = dict( | |
rotation_range=90., | |
width_shift_range=0.05, | |
height_shift_range=0.05, | |
zoom_range=0.2, | |
channel_shift_range=0.005, | |
horizontal_flip=True, | |
vertical_flip=True, | |
fill_mode='constant', | |
data_format="channels_last", | |
) | |
image_datagen = ImageDataGenerator(**data_gen_args) | |
mask_datagen = ImageDataGenerator(**data_gen_args) | |
X_train = HDF5Matrix(os.path.join(out_dir, 'train_X_3band.h5'), 'X') | |
y_train = HDF5Matrix(os.path.join(out_dir, 'train_y_3class.h5'), 'y') | |
image_generator = image_datagen.flow( | |
X_train, None, | |
seed=seed, | |
batch_size=batch_size, | |
) | |
mask_generator = mask_datagen.flow( | |
y_train, None, | |
seed=seed, | |
batch_size=batch_size, | |
) | |
# combine generators into one which yields image and masks | |
train_generator = zip(image_generator, mask_generator) | |
train_generator | |
X, y = next(train_generator) | |
X.shape, y.shape |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""How to do data augmentation on a keras HDF5Matrix""" | |
from keras.utils.io_utils import HDF5Matrix | |
class AugumentedHDF5Matrix(HDF5Matrix):
    """HDF5Matrix that applies a random keras image transform on access.

    Every ``__getitem__`` call bumps an internal counter, and the transform
    is seeded with ``seed + counter`` — so a fresh instance built with the
    same ``seed`` replays the same augmentation sequence, while repeated
    reads on one instance yield different augmentations.
    """

    def __init__(self, image_datagen, seed, *args, **kwargs):
        # keras ImageDataGenerator providing `random_transform`.
        self.image_datagen = image_datagen
        self.seed = seed
        # Access counter; advances the effective seed on every read.
        self.i = 0
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        sample = super().__getitem__(key)
        self.i += 1
        transform = self.image_datagen.random_transform
        effective_seed = self.seed + self.i
        if len(sample.shape) == 3:
            # Single image (H, W, C).
            return transform(sample, seed=effective_seed)
        # Batch of images — note every image in the batch shares one seed.
        return np.array([transform(img, seed=effective_seed) for img in sample])
# --- Smoke test -------------------------------------------------------------
from keras.preprocessing.image import ImageDataGenerator
import h5py
import numpy as np
from matplotlib import pyplot as plt

# A keras image-data generator with mild geometric and channel augmentation.
image_datagen = ImageDataGenerator(
    width_shift_range=0.05,
    height_shift_range=0.05,
    zoom_range=0.1,
    channel_shift_range=0.005,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='constant',
    data_format="channels_last",
    rescale=1 / 255.0)

# Synthetic HDF5 fixture: noise images with a bright and a dark rectangle,
# so the applied transforms are visible when plotted.
images = np.random.random((100, 244, 244, 3))
images[:, 20:30, 20:50, :] = 1
images[:, 50:70, 20:30, :] = 0
datapath = "/tmp/testfile5.hdf5"
with h5py.File(datapath, "w") as h5file:
    dataset = h5file.create_dataset("X", data=images)

# A fresh instance with the same seed must reproduce the first read...
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X')
first_read = X[0].mean()
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X')
repeat_read = X[0].mean()
assert first_read == repeat_read, 'should be repeatable'
# ...while a second read on the SAME instance uses a new seed.
next_read = X[0].mean()
assert repeat_read != next_read, 'and random'

# Slicing and fancy indexing should still work.
X[1:2][0]
X[[1, 2]][0]

# Visual check: each draw shows a differently augmented image.
for _ in range(5):
    plt.imshow(X[0])
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Ideally, to work with keras'
fit_generator()
function, AugumentedHDF5Matrix
should be implemented as an iterator. This might still work since you can iterate over the object (because of a legacy Python behavior, as explained here), but I wouldn't count on it...