Skip to content

Instantly share code, notes, and snippets.

@helderc
Forked from wassname/augumented_hdf5_matrix.py
Created August 10, 2018 11:51
Show Gist options
  • Save helderc/112562d6090cbd535d0c26bbfd952f77 to your computer and use it in GitHub Desktop.
How to do data augmentation on a keras HDF5Matrix
"""Another way; note this one will load the whole array into memory."""
import os

import h5py
from keras.preprocessing.image import ImageDataGenerator
from keras.utils.io_utils import HDF5Matrix

# NOTE(review): `out_dir` was never defined in the original snippet, so the
# script raised NameError. Set it to the directory that contains
# train_X_3band.h5 and train_y_3class.h5.
out_dir = "."

seed = 0
batch_size = 32

# We create two instances with the same arguments. Both generators below are
# also given the same seed, presumably so that each image and its mask receive
# the same random transform -- keep these settings identical for both.
data_gen_args = dict(
    rotation_range=90.,
    width_shift_range=0.05,
    height_shift_range=0.05,
    zoom_range=0.2,
    channel_shift_range=0.005,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='constant',
    data_format="channels_last",
)
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

X_train = HDF5Matrix(os.path.join(out_dir, 'train_X_3band.h5'), 'X')
y_train = HDF5Matrix(os.path.join(out_dir, 'train_y_3class.h5'), 'y')

image_generator = image_datagen.flow(
    X_train, None,
    seed=seed,
    batch_size=batch_size,
)
mask_generator = mask_datagen.flow(
    y_train, None,
    seed=seed,
    batch_size=batch_size,
)

# Combine generators into one which yields (images, masks) pairs.
train_generator = zip(image_generator, mask_generator)

# Pull one batch to sanity-check the pipeline.
X, y = next(train_generator)
print(X.shape, y.shape)
"""How to do data augmentation on a keras HDF5Matrix."""
from keras.utils.io_utils import HDF5Matrix
import numpy as np


class AugumentedHDF5Matrix(HDF5Matrix):
    """Wrap an HDF5Matrix so that every read returns augmented images.

    Each ``__getitem__`` call reads from the underlying HDF5 dataset and
    passes the result through ``image_datagen.random_transform`` with an
    explicit seed. The seed is ``seed`` plus a per-instance counter that
    advances by one for every sample returned. As a result:

    * two instances built with the same ``seed`` and indexed with the same
      sequence of keys produce the same transforms (e.g. images and their
      masks stay aligned), and
    * each sample within one batch gets its own transform instead of the
      whole batch sharing a single one.

    Args:
        image_datagen: object exposing ``random_transform(x, seed=...)``,
            e.g. a ``keras.preprocessing.image.ImageDataGenerator``.
        seed: base integer seed.
        *args, **kwargs: forwarded unchanged to ``HDF5Matrix.__init__``
            (e.g. ``datapath, dataset``).
    """

    def __init__(self, image_datagen, seed, *args, **kwargs):
        self.image_datagen = image_datagen
        self.seed = seed
        # Number of samples transformed so far. It offsets the seed so that
        # successive reads differ while remaining reproducible.
        self.i = 0
        super().__init__(*args, **kwargs)

    def _next_seed(self):
        """Advance the sample counter and return the seed for the next sample."""
        self.i += 1
        return self.seed + self.i

    def __getitem__(self, key):
        x = super().__getitem__(key)
        if len(x.shape) == 3:
            # 3-D result: a single image (e.g. an integer key), one transform.
            return self.image_datagen.random_transform(
                x, seed=self._next_seed())
        # Otherwise treat the leading axis as the batch (slice / list key) and
        # draw a fresh seed per sample, so the batch is not augmented with one
        # shared transform. Seeds are consumed in order, so the result is
        # still deterministic.
        return np.array([
            self.image_datagen.random_transform(xx, seed=self._next_seed())
            for xx in x
        ])
# Test
from keras.preprocessing.image import ImageDataGenerator
import h5py
import numpy as np
from matplotlib import pyplot as plt

# A Keras ImageDataGenerator that supplies the random transforms.
# NOTE(review): in Keras, `rescale` is applied by standardize(), not by
# random_transform(), so it most likely has no effect on the values returned
# by AugumentedHDF5Matrix -- confirm for your Keras version.
image_datagen = ImageDataGenerator(
    width_shift_range=0.05,
    height_shift_range=0.05,
    zoom_range=0.1,
    channel_shift_range=0.005,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='constant',
    data_format="channels_last",
    rescale=1 / 255.0,
)

# Build a small test HDF5 file: random noise plus two constant rectangles,
# so the geometric transforms are visible when plotted.
images = np.random.random((100, 244, 244, 3))
images[:, 20:30, 20:50, :] = 1
images[:, 50:70, 20:30, :] = 0
datapath = "/tmp/testfile5.hdf5"
with h5py.File(datapath, "w") as f:
    f.create_dataset("X", data=images)

# Same seed and same access sequence -> same augmentation.
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X')
a = X[0].mean()
X = AugumentedHDF5Matrix(image_datagen, 0, datapath, 'X')
b = X[0].mean()
assert a == b, 'should be repeatable'
# A second read from the same instance uses the next seed.
c = X[0].mean()
assert b != c, 'and random'

# Should be able to slice and index with a list; each returns a batch whose
# samples keep the per-image shape.
assert X[1:2][0].shape == images.shape[1:]
assert X[[1, 2]][0].shape == images.shape[1:]

# View a few different augmentations of the same image.
for _ in range(5):
    plt.imshow(X[0])
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment