Last active May 10, 2018 14:39
-
-
Save Mirodil/a318fb03a8f8062b41b5f8e7757150d4 to your computer and use it in GitHub Desktop.
FileListIterator for Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
from keras import backend as K
from keras.preprocessing.image import (Iterator, array_to_img, img_to_array,
                                       load_img)
class FileListIterator(Iterator):
    """Iterator capable of reading images from an explicit list of files.

    Unlike Keras' `DirectoryIterator`, samples and their labels are passed
    in as two parallel sequences instead of being discovered from a
    directory tree.

    # Arguments
        filenames: Sequence of paths to the image files to read.
        fileClasses: Class label of each file in `filenames`, in the same
            order. Must have the same length as `filenames`.
        image_data_generator: Instance of `ImageDataGenerator`
            to use for random transformations and normalization.
        target_size: tuple of integers, dimensions to resize input images to.
        color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images.
        classes: Optional list of class labels (e.g. `["dogs", "cats"]`);
            its order defines the class indices. If not set, it is computed
            as the sorted set of distinct labels in `fileClasses`, so the
            label -> index mapping is identical on every run.
        class_mode: Mode for yielding the targets:
            `"binary"`: binary targets (if there are only two classes),
            `"categorical"`: categorical targets,
            `"sparse"`: integer targets,
            `"input"`: targets are images identical to input images (mainly
                used to work with autoencoders),
            `None`: no targets get yielded (only input images are yielded).
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        data_format: String, one of `channels_first`, `channels_last`.
            Defaults to `K.image_data_format()`.
        save_to_dir: Optional directory where to save the pictures
            being yielded, in a viewable format. This is useful
            for visualizing the random transformations being
            applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample
            images (if `save_to_dir` is set).
        save_format: Format to use for saving sample images
            (if `save_to_dir` is set).
        follow_links: Unused; accepted only for signature compatibility
            with `DirectoryIterator`.
        subset: Subset of data (`"training"` or `"validation"`) if
            `validation_split` is set in `ImageDataGenerator`. The split is
            applied per class, like `DirectoryIterator`: the first
            `validation_split` fraction of each class's files (in list
            order) is the validation subset, the remainder the training one.
        interpolation: Interpolation method used to resample the image if the
            target size is different from that of the loaded image.
            Supported methods are "nearest", "bilinear", and "bicubic".
            If PIL version 1.1.3 or newer is installed, "lanczos" is also
            supported. If PIL version 3.4.0 or newer is installed, "box" and
            "hamming" are also supported. By default, "nearest" is used.

    # Raises
        ValueError: on an invalid `color_mode`, `class_mode` or `subset`,
            when `filenames` and `fileClasses` differ in length, or when a
            label in `fileClasses` does not appear in `classes`.

    # Examples
    ```python
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    filenames = ['path/to/file1.png', 'path/to/file2.png', ...]
    fileClasses = ['scottish_deerhound', 'entlebucher', ...]
    fileListIterator = FileListIterator(
        filenames,
        fileClasses,
        train_datagen,
        target_size=(256, 256),
        color_mode='grayscale',
        classes=None,
        class_mode='categorical',
        data_format=train_datagen.data_format,
        batch_size=32,
        shuffle=True,
        seed=None,
        save_to_dir=None,
        save_prefix='',
        save_format='png',
        follow_links=False,
        subset=None,
        interpolation='nearest')
    ```
    """

    def __init__(self,
                 filenames,
                 fileClasses,
                 image_data_generator,
                 target_size=(256, 256),
                 color_mode='rgb',
                 classes=None,
                 class_mode='categorical',
                 batch_size=32,
                 shuffle=True,
                 seed=None,
                 data_format=None,
                 save_to_dir=None,
                 save_prefix='',
                 save_format='png',
                 follow_links=False,
                 subset=None,
                 interpolation='nearest'):
        if data_format is None:
            data_format = K.image_data_format()
        self.image_data_generator = image_data_generator
        self.target_size = tuple(target_size)
        if color_mode not in {'rgb', 'grayscale'}:
            raise ValueError('Invalid color mode:', color_mode,
                             '; expected "rgb" or "grayscale".')
        self.color_mode = color_mode
        self.data_format = data_format
        channels = (3,) if self.color_mode == 'rgb' else (1,)
        if self.data_format == 'channels_last':
            self.image_shape = self.target_size + channels
        else:
            self.image_shape = channels + self.target_size
        if class_mode not in {'categorical', 'binary', 'sparse',
                              'input', None}:
            raise ValueError('Invalid class_mode:', class_mode,
                             '; expected one of "categorical", '
                             '"binary", "sparse", "input"'
                             ' or None.')
        self.class_mode = class_mode
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
        self.interpolation = interpolation

        if subset is not None:
            validation_split = self.image_data_generator._validation_split
            if subset == 'validation':
                split = (0, validation_split)
            elif subset == 'training':
                split = (validation_split, 1)
            else:
                raise ValueError('Invalid subset name: ', subset,
                                 '; expected "training" or "validation"')
        else:
            split = None
        self.subset = subset

        # Accept any sequence (list, tuple, numpy array, pandas Series ...).
        filenames = list(filenames)
        fileClasses = list(fileClasses)
        if len(filenames) != len(fileClasses):
            raise ValueError('filenames and fileClasses must have the same '
                             'length; got %d and %d.'
                             % (len(filenames), len(fileClasses)))

        # Build the label -> index mapping from the FULL list (before any
        # subset split) so training and validation iterators agree. Sorting
        # keeps it stable across runs: iteration order of a set of str
        # labels changes between interpreter runs due to hash randomization.
        if not classes:
            try:
                classes = sorted(set(fileClasses))
            except TypeError:
                # Unorderable mixed label types: use first-appearance order,
                # which is still deterministic.
                classes, seen = [], set()
                for label in fileClasses:
                    if label not in seen:
                        seen.add(label)
                        classes.append(label)
        self.num_classes = len(classes)
        self.class_indices = dict(zip(classes, range(len(classes))))
        unknown = set(fileClasses) - set(self.class_indices)
        if unknown:
            raise ValueError('Labels in fileClasses not found in classes: %s'
                             % (sorted(unknown, key=repr),))

        if split is not None:
            # Split each class separately (as DirectoryIterator does per
            # class subdirectory) so every class appears in both subsets.
            per_class = {}
            for i, label in enumerate(fileClasses):
                per_class.setdefault(label, []).append(i)
            keep = []
            for indices in per_class.values():
                start = int(split[0] * len(indices))
                stop = int(split[1] * len(indices))
                keep.extend(indices[start:stop])
            keep.sort()  # preserve the caller's original file order
            filenames = [filenames[i] for i in keep]
            fileClasses = [fileClasses[i] for i in keep]

        self.filenames = filenames
        self.samples = len(self.filenames)
        self.classes = np.array([self.class_indices[label]
                                 for label in fileClasses], dtype='int32')
        print('Found %d images belonging to %d classes.'
              % (self.samples, self.num_classes))
        super(FileListIterator, self).__init__(self.samples, batch_size,
                                               shuffle, seed)

    def _get_batches_of_transformed_samples(self, index_array):
        """Load, augment and label the samples selected by `index_array`.

        # Arguments
            index_array: Array of sample indices forming this batch.

        # Returns
            `batch_x` alone when `class_mode` is None, otherwise the tuple
            `(batch_x, batch_y)`.
        """
        batch_x = np.zeros((len(index_array),) + self.image_shape,
                           dtype=K.floatx())
        grayscale = self.color_mode == 'grayscale'
        # build batch of image data
        for i, j in enumerate(index_array):
            img = load_img(self.filenames[j],
                           grayscale=grayscale,
                           target_size=self.target_size,
                           interpolation=self.interpolation)
            x = img_to_array(img, data_format=self.data_format)
            x = self.image_data_generator.random_transform(x)
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x
        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = array_to_img(batch_x[i], self.data_format, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(
                    prefix=self.save_prefix,
                    index=j,
                    hash=np.random.randint(10 ** 7),
                    format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))
        # build batch of labels
        if self.class_mode == 'input':
            batch_y = batch_x.copy()
        elif self.class_mode == 'sparse':
            batch_y = self.classes[index_array]
        elif self.class_mode == 'binary':
            batch_y = self.classes[index_array].astype(K.floatx())
        elif self.class_mode == 'categorical':
            batch_y = np.zeros((len(batch_x), self.num_classes),
                               dtype=K.floatx())
            # one-hot: set column `label` of each row to 1
            batch_y[np.arange(len(batch_x)), self.classes[index_array]] = 1.
        else:
            return batch_x
        return batch_x, batch_y

    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        return self._get_batches_of_transformed_samples(index_array)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.