Created December 28, 2017 05:46
Keras HDF5Matrix and fit_generator for huge hdf5 dataset
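The idea is that HDF5Matrix wraps an HDF5 dataset and only reads the rows you actually slice, so the generator below can stream batches from a file far larger than RAM. A minimal illustration of that behaviour, not part of the original gist (it assumes an existing train.h5 containing a dataset named 'x', matching the file and dataset names used further down):

    # Sketch only: shows the lazy-slicing behaviour the generator relies on.
    from keras.utils.io_utils import HDF5Matrix

    x = HDF5Matrix('train.h5', 'x')   # opens the file; no sample data is loaded yet
    print(x.end)                      # total number of samples in the dataset
    batch = x[0:32]                   # slicing reads just these 32 samples from disk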
import threading

from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import Adam
from keras.utils.io_utils import HDF5Matrix


class threadsafe_iter:
    """Takes an iterator/generator and makes it thread-safe by
    serializing calls to the `next` method of the given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g


@threadsafe_generator
def generator(hdf5_file, batch_size):
    """Yield (x, y) batches from an HDF5 file indefinitely, without loading it all into memory."""
    x = HDF5Matrix(hdf5_file, 'x')
    size = x.end  # total number of samples
    y = HDF5Matrix(hdf5_file, 'y')
    idx = 0
    while True:
        last_batch = idx + batch_size > size
        end = idx + batch_size if not last_batch else size
        yield x[idx:end], y[idx:end]
        # wrap around to the start of the file after the final (possibly smaller) batch
        idx = end if not last_batch else 0


def data_statistic(train_dataset, test_dataset):
    """Return the number of samples in the train and test HDF5 files."""
    train_x = HDF5Matrix(train_dataset, 'x')
    test_x = HDF5Matrix(test_dataset, 'x')
    return train_x.end, test_x.end


def build_model():
    m = InceptionV3(weights=None)
    return m


if __name__ == '__main__':
    batch_size = 32
    train_dataset = 'train.h5'
    test_dataset = 'test.h5'

    model = build_model()
    model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])

    train_generator = generator(train_dataset, batch_size)
    test_generator = generator(test_dataset, batch_size)
    nb_train_samples, nb_test_samples = data_statistic(train_dataset, test_dataset)
    print('train samples: %d, test samples: %d' % (nb_train_samples, nb_test_samples))

    model.fit_generator(
        epochs=10,
        generator=train_generator, steps_per_epoch=nb_train_samples // batch_size,
        validation_data=test_generator, validation_steps=nb_test_samples // batch_size,
        max_queue_size=10,          # choose a value so that batch_size * image_size * max_queue_size fits in CPU memory
        workers=1,                  # extra workers give no performance benefit here without multithreading
        use_multiprocessing=False,  # HDF5Matrix does not support concurrent access from multiple processes
        shuffle=False)              # an HDF5Matrix cannot be shuffled, so shuffle the data before saving it to the h5 file
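Since the final comment says the data has to be shuffled before it is written to the h5 file, here is a minimal sketch (not part of the original gist) of how a compatible train.h5 or test.h5 could be produced with h5py. The 'x' and 'y' dataset names match what the generator above expects; the function name, argument names, and the assumption that the arrays fit in memory are illustrative only:

    # Hypothetical helper: write a pre-shuffled dataset with the layout the gist expects.
    import h5py
    import numpy as np

    def write_shuffled_h5(path, images, labels):
        """images: e.g. (N, 299, 299, 3) float32; labels: (N, num_classes) one-hot -- example shapes only."""
        perm = np.random.permutation(len(images))      # shuffle once, up front
        with h5py.File(path, 'w') as f:
            f.create_dataset('x', data=images[perm])   # read back later via HDF5Matrix(path, 'x')
            f.create_dataset('y', data=labels[perm])   # read back later via HDF5Matrix(path, 'y')

For datasets too large to hold in memory at once, the same idea applies, but the shuffled permutation would have to be written out in chunks instead of with a single create_dataset call.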