Created
September 9, 2017 04:03
-
-
Save Hironsan/e041d6606164bc14c50aa56b989c5fc0 to your computer and use it in GitHub Desktop.
fit vs fit_generator in Keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from keras.preprocessing import sequence | |
from keras.models import Sequential | |
from keras.layers import Dense, Embedding | |
from keras.layers import LSTM | |
from keras.datasets import imdb | |
def batch_iter(data, labels, batch_size, shuffle=True):
    """Build an endless mini-batch generator over ``data`` and ``labels``.

    Args:
        data: Samples. Must support ``len()`` and slicing; when ``shuffle`` is
            true it must also accept an integer index array (e.g. a NumPy
            array).
        labels: Targets aligned with ``data``; same indexing requirements.
        batch_size: Positive number of samples per batch.
        shuffle: If true, draw a fresh random permutation at the start of
            every pass over the data.

    Returns:
        A ``(num_batches_per_epoch, generator)`` tuple. The generator yields
        ``(X, y)`` pairs forever, starting a new pass after
        ``num_batches_per_epoch`` batches. The last batch of a pass may hold
        fewer than ``batch_size`` samples.

    Raises:
        ValueError: If ``batch_size`` is not positive, ``data`` is empty, or
            ``data`` and ``labels`` differ in length.
    """
    data_size = len(data)

    # Validate eagerly (not inside the generator) so that a bad call fails
    # here rather than on the first ``next()`` deep inside a training loop.
    if batch_size <= 0:
        raise ValueError(
            'batch_size must be positive, got {!r}'.format(batch_size))
    if data_size == 0:
        # An empty pass would make the endless generator spin without ever
        # producing a usable batch.
        raise ValueError('data must contain at least one sample')
    if len(labels) != data_size:
        raise ValueError(
            'data and labels must have the same length, got {} and {}'.format(
                data_size, len(labels)))

    # Ceiling division in pure integer arithmetic: exact for any size and
    # independent of whether ``/`` is true or floor division.
    num_batches_per_epoch = (data_size - 1) // batch_size + 1

    def data_generator():
        while True:
            # Shuffle the data at each epoch
            if shuffle:
                order = np.random.permutation(data_size)
                epoch_data, epoch_labels = data[order], labels[order]
            else:
                epoch_data, epoch_labels = data, labels

            # Slicing clamps at the end of the sequence, so the final batch
            # is simply shorter when batch_size does not divide data_size.
            for start in range(0, data_size, batch_size):
                stop = start + batch_size
                yield epoch_data[start:stop], epoch_labels[start:stop]

    return num_batches_per_epoch, data_generator()
def main(mode):
    """Train a small Embedding + LSTM binary classifier for one epoch.

    The data comes from ``imdb.load_data`` and is passed through
    ``sequence.pad_sequences`` with a fixed ``maxlen``.

    Args:
        mode: ``'fit'`` trains with ``model.fit`` on the full arrays; any
            other value trains with ``model.fit_generator`` on mini-batches
            produced by ``batch_iter``.
    """
    vocab_size = 20000
    sequence_length = 80
    batch_size = 32

    # Load the dataset and bring every sequence to a common length.
    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=vocab_size)
    x_train = sequence.pad_sequences(x_train, maxlen=sequence_length)
    x_test = sequence.pad_sequences(x_test, maxlen=sequence_length)

    # Stack the layers in order: embedding -> LSTM -> sigmoid output unit.
    model = Sequential()
    for layer in (
        Embedding(vocab_size, 128),
        LSTM(128, dropout=0.2, recurrent_dropout=0.2),
        Dense(1, activation='sigmoid'),
    ):
        model.add(layer)
    model.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )

    if mode == 'fit':
        # Whole-array training; the test split doubles as validation data.
        model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=1,
            validation_data=(x_test, y_test),
        )
        return

    # Generator-based training: batch_iter supplies both the number of steps
    # per epoch and the endless batch generator for each split.
    train_steps, train_batches = batch_iter(x_train, y_train, batch_size)
    valid_steps, valid_batches = batch_iter(x_test, y_test, batch_size)
    model.fit_generator(
        train_batches,
        train_steps,
        epochs=1,
        validation_data=valid_batches,
        validation_steps=valid_steps,
    )
if __name__ == '__main__':
    import sys

    # Script entry point: the first command-line argument selects the
    # training mode passed to main(). Without this check a missing argument
    # surfaces as a bare IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit(
            "usage: python {} MODE\n"
            "  MODE: 'fit' to train with model.fit; any other value trains "
            "with model.fit_generator".format(sys.argv[0])
        )

    mode = sys.argv[1]
    main(mode)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Great code! Simple and very useful! One question regarding line 10. I wonder why this line is not like this:
I guess you were trying to avoid the truncation that integer division performs, but that can also be achieved by casting the denominator to float. Or is there another reason that I'm missing?
Thanks!