How to do data augmentation on a keras HDF5Matrix
"""Another way, note this one will load the whole array into memory .""" | |
from keras.preprocessing.image import ImageDataGenerator | |
import h5py | |
from keras.utils.io_utils import HDF5Matrix | |
seed=0 | |
batch_size=32 | |
# we create two instances with the same arguments | |
data_gen_args = dict( | |
rotation_range=90., | |
width_shift_range=0.05, | |
height_shift_range=0.05, | |
zoom_range=0.2, | |
channel_shift_range=0.005, | |
horizontal_flip=True, | |
vertical_flip=True, | |
fill_mode='constant', | |
data_format="channels_last", | |
) | |
image_datagen = ImageDataGenerator(**data_gen_args) | |
mask_datagen = ImageDataGenerator(**data_gen_args) | |
X_train = HDF5Matrix(os.path.join(out_dir, 'train_X_3band.h5'), 'X') | |
y_train = HDF5Matrix(os.path.join(out_dir, 'train_y_3class.h5'), 'y') | |
image_generator = image_datagen.flow( | |
X_train, None, | |
seed=seed, | |
batch_size=batch_size, | |
) | |
mask_generator = mask_datagen.flow( | |
y_train, None, | |
seed=seed, | |
batch_size=batch_size, | |
) | |
# combine generators into one which yields image and masks | |
train_generator = zip(image_generator, mask_generator) | |
train_generator | |
X, y = next(train_generator) | |
X.shape, y.shape |
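Since zip of two infinite generators is itself an infinite generator, the combined train_generator can be passed to fit_generator. A minimal usage sketch, where `model` and `n_train` are hypothetical placeholders not defined in this gist:

# Hypothetical usage: `model` is a compiled Keras model, n_train is the
# number of training images; steps_per_epoch tells Keras when an epoch ends,
# since the zipped generator never raises StopIteration.
model.fit_generator(
    train_generator,
    steps_per_epoch=n_train // batch_size,
    epochs=10,
)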
"""How to do data augmentation on a keras HDF5Matrix""" | |
from keras.utils.io_utils import HDF5Matrix | |
class AugumentedHDF5Matrix(HDF5Matrix): | |
"""Wraps HDF5Matrixs with image augumentation.""" | |
def __init__(self, image_datagen, seed, *args, **kwargs): | |
self.image_datagen = image_datagen | |
self.seed = seed | |
self.i = 0 | |
super().__init__(*args, **kwargs) | |
def __getitem__(self, key): | |
x = super().__getitem__(key) | |
self.i += 1 | |
if len(x.shape) == 3: | |
return self.image_datagen.random_transform( | |
x, seed=self.seed + self.i) | |
else: | |
return np.array([ | |
self.image_datagen.random_transform( | |
xx, seed=self.seed + self.i) for xx in x | |
]) | |
# Test | |
from keras.preprocessing.image import ImageDataGenerator | |
import h5py | |
import numpy as np | |
from matplotlib import pyplot as plt | |
# a keras imagedata generator | |
image_datagen = ImageDataGenerator( | |
width_shift_range=0.05, | |
height_shift_range=0.05, | |
zoom_range=0.1, | |
channel_shift_range=0.005, | |
horizontal_flip=True, | |
vertical_flip=True, | |
fill_mode='constant', | |
data_format="channels_last", | |
rescale=1 / 255.0) | |
# test h5 file | |
images = np.random.random((100, 244, 244, 3)) | |
images[:, 20:30, 20:50, :] = 1 | |
images[:, 50:70, 20:30, :] = 0 | |
datapath = "/tmp/testfile5.hdf5" | |
with h5py.File(datapath, "w") as f: | |
dst = f.create_dataset("X", data=images) | |
# Test | |
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X') | |
a = X[0].mean() | |
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X') | |
b = X[0].mean() | |
assert a == b, 'should be repeatable' | |
c = X[0].mean() | |
assert b != c, 'and random' | |
# Should be able to slice | |
X[1:2][0] | |
X[[1, 2]][0] | |
# View | |
for _ in range(5): | |
plt.imshow(X[0]) | |
plt.show() |
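Because AugumentedHDF5Matrix is still an HDF5Matrix, it should be usable wherever old Keras accepted HDF5Matrix inputs directly. A hypothetical sketch, assuming `model` is a compiled Keras model and `y_train` a matching HDF5Matrix of labels (neither is defined in this gist):

# shuffle='batch' is Keras' special mode for HDF5 data: it shuffles in
# batch-sized chunks so reads from disk stay contiguous.
X_train = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X')
model.fit(X_train, y_train, batch_size=32, shuffle='batch')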
wassname (author) commented on Aug 19, 2017:
Interesting code, but I have one question: is it scalable in the case where my data inside the HDF5 file does not fit into memory?

Probably not. I've moved on to making multiple HDF5 files of ~400 MB each, then loading the whole lot as a dask array.
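A minimal sketch of that multi-file dask approach, assuming the shards all expose the same dataset name 'X' (the glob pattern and chunk shape here are assumptions, not from the gist):

import glob
import h5py
import dask.array as da

# Keep the file handles open so dask can read lazily; chunks are only
# pulled from disk when a slice is actually computed.
files = [h5py.File(p, 'r') for p in sorted(glob.glob('shards/*.h5'))]
arrays = [da.from_array(f['X'], chunks=(32, 244, 244, 3)) for f in files]
X = da.concatenate(arrays, axis=0)  # one virtual array over all shards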
As explained in this issue, https://github.com/keras-team/keras/issues/2674#issuecomment-218036900, the data is not loaded into memory but read from disk, so it is not necessary for the HDF5 data to be small enough to fit in memory.
Ideally, to work with keras' fit_generator() function, AugumentedHDF5Matrix should be implemented as an iterator. It might still work as-is, since you can iterate over the object anyway (because of a legacy Python behavior: any object whose __getitem__ accepts integer indices from zero upward is iterable, as explained here), but I wouldn't count on that...
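A sketch of what that iterator-style wrapper could look like using keras.utils.Sequence, which fit_generator accepts directly. The class name and parameters are assumptions, and it yields input batches only; labels or masks would need the same treatment:

import numpy as np
from keras.utils import Sequence
from keras.utils.io_utils import HDF5Matrix

class AugmentedHDF5Sequence(Sequence):
    """Hypothetical batch iterator over an HDF5Matrix for fit_generator."""

    def __init__(self, image_datagen, datapath, dataset, batch_size=32, seed=0):
        self.image_datagen = image_datagen
        self.x = HDF5Matrix(datapath, dataset)
        self.batch_size = batch_size
        self.seed = seed

    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        # read one contiguous batch from disk, then augment each image
        batch = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        return np.array([
            self.image_datagen.random_transform(x, seed=self.seed + idx)
            for x in batch
        ])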