Skip to content

Instantly share code, notes, and snippets.

@tldrafael
Created January 17, 2020 15:18
Show Gist options
  • Save tldrafael/df7fd7cbcc0b104f3bfd4f964031546e to your computer and use it in GitHub Desktop.
Save tldrafael/df7fd7cbcc0b104f3bfd4f964031546e to your computer and use it in GitHub Desktop.
Generic example of input generator for Keras
import numpy as np
class inputGen:
    """Endless mini-batch generator over in-memory arrays ``X`` and ``y``.

    Every batch has ``min(batch_size, n_samples)`` rows: when the samples
    left in an epoch cannot fill a batch, the final batch is slid back so it
    ends at the last sample (overlapping the previous batch). After each
    epoch the cursor returns to 0 and, if ``shuffle`` is True, the sample
    order is reshuffled with ``np.random.shuffle``.

    Args:
        batch_size: Number of samples per batch; must be >= 1.
        X: Inputs. Indexed along axis 0 with an integer array, so it must
            have a ``shape`` attribute and support NumPy fancy indexing.
        y: Targets, indexed the same way as ``X``. Presumably has the same
            number of rows as ``X`` — this is not checked here.
        shuffle: Shuffle the sample order at construction and after every
            completed epoch.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """

    def __init__(self, batch_size, X, y, shuffle=True):
        if batch_size < 1:
            # A non-positive batch never advances the cursor, so the
            # generator would otherwise yield empty batches forever.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        self.X = X
        self.y = y
        self.cursor = 0
        self.n_samples = X.shape[0]
        # Order in which rows are served; batches are slices of this array.
        self.ids_sequence = np.arange(self.n_samples)
        # Row indices of the most recently yielded batch (None before the first).
        self.ids_batch = None
        self.shuffle = shuffle
        if shuffle:
            self.shuffle_ids()

    def shuffle_ids(self):
        """Permute ``ids_sequence`` in place using NumPy's global RNG."""
        np.random.shuffle(self.ids_sequence)

    def generator(self):
        """Yield ``(X_batch, y_batch)`` tuples forever.

        The read position lives on the instance (``self.cursor``), not in
        the generator object, and is advanced when the next batch is
        requested.
        """
        while True:
            cursor_start = self.cursor
            cursor_end = cursor_start + self.batch_size
            if cursor_end > self.n_samples:
                # Keep the batch full by sliding it back to end at the last sample.
                cursor_end = self.n_samples
                cursor_start = max(0, cursor_end - self.batch_size)
            # Copy: a slice is a view of ids_sequence, which shuffle_ids
            # permutes in place at the end of each epoch.
            self.ids_batch = self.ids_sequence[cursor_start:cursor_end].copy()
            yield self.X[self.ids_batch], self.y[self.ids_batch]
            self.update_cursor()

    def update_cursor(self):
        """Advance the cursor by one batch; begin a new epoch when exhausted."""
        self.cursor += self.batch_size
        # `>=` rather than `>`: when n_samples is a multiple of batch_size the
        # cursor lands exactly on n_samples, which already ends the epoch;
        # with `>` the last batch would be served twice.
        if self.cursor >= self.n_samples:
            self.cursor = 0
            if self.shuffle:
                self.shuffle_ids()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment