Created April 27, 2019 03:08
Save steermomo/248a629075b2f46e74b81a38a09f61f4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
import keras
class DataGenerator(keras.utils.Sequence):
    """Generate batches of data for Keras from per-sample ``.npy`` files.

    Each sample ID in ``list_IDs`` is loaded from
    ``os.path.join(data_dir, str(ID) + '.npy')`` and paired with the integer
    class label ``labels[ID]``. Labels are one-hot encoded with
    ``keras.utils.to_categorical``.

    Only full batches are produced: ``len(self)`` is
    ``len(list_IDs) // batch_size``. Trailing samples that do not fill a batch
    are skipped for that epoch. If ``batch_size > len(list_IDs)``, no batches
    are produced at all.

    Args:
        list_IDs: Sequence of sample identifiers.
        labels: Mapping from sample ID to integer class label.
        batch_size: Number of samples per batch.
        dim: Shape of one sample, excluding the trailing channel axis.
        n_channels: Size of the trailing channel axis.
        n_classes: Number of classes used for the one-hot labels.
        shuffle: If true, reshuffle the sample order after every epoch.
        data_dir: Directory containing the ``.npy`` files. Defaults to
            ``'data/'``, the path previously hard-coded.
    """

    def __init__(self, list_IDs, labels, batch_size=32, dim=(32, 32, 32),
                 n_channels=1, n_classes=10, shuffle=True, data_dir='data/'):
        """Store the configuration and build the initial sample order."""
        # Keras 3's PyDataset (aliased as keras.utils.Sequence) expects its
        # __init__ to run. For older Sequence classes this is a harmless no-op.
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.data_dir = data_dir
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, y)`` pair.

        Raises:
            IndexError: If ``index`` is not in ``range(len(self))``. Such an
                index would otherwise yield an empty or partial batch.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                'batch index %r out of range for %d batches' % (index, len(self)))
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in batch_indexes]
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        """Reset the sample order, reshuffling it if ``shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load the samples in ``list_IDs_temp`` and one-hot encode their labels.

        Returns:
            ``X`` with shape ``(len(list_IDs_temp), *dim, n_channels)``, and
            ``y`` from ``keras.utils.to_categorical`` with
            ``num_classes=n_classes``.
        """
        n_samples = len(list_IDs_temp)
        # Size the arrays by the actual number of IDs so that no row is ever
        # left as uninitialized np.empty memory.
        X = np.empty((n_samples, *self.dim, self.n_channels))
        y = np.empty(n_samples, dtype=int)
        for i, ID in enumerate(list_IDs_temp):
            # Each stored array must broadcast to (*dim, n_channels).
            X[i] = np.load(os.path.join(self.data_dir, str(ID) + '.npy'))
            y[i] = self.labels[ID]
        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.