Convert the MNIST digit example to use a generator
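The script below wraps the MNIST training data in a keras.utils.Sequence subclass (DataGenerator) and feeds it to Model.fit_generator, with workarounds for two TensorFlow issues (GPU memory growth and an eager-execution problem with generators; see the issue links in the code). A Sequence only needs to implement __len__ (batches per epoch) and __getitem__ (return one batch); on_epoch_end is an optional hook, used here to reshuffle the sample order.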
import os
# Work around for https://github.com/tensorflow/tensorflow/issues/24496
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# Work around for https://github.com/tensorflow/tensorflow/issues/33024
import tensorflow.compat as compat
compat.v1.disable_eager_execution()

# baseline CNN model for MNIST
import time
import numpy as np
from numpy import mean
from numpy import std
from matplotlib import pyplot
from sklearn.model_selection import KFold
import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.optimizers import SGD
class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, trainX, trainY, numBatches=32, shuffle=True):
        'Initialization'
        self.trainX = trainX
        self.trainY = trainY
        self.numBatches = numBatches
        self.shuffle = shuffle
        # index of every sample; shuffled between epochs when shuffle=True
        self.ind = np.arange(len(self.trainX))
        # boundary offsets that split the index array into numBatches batches
        self.batchInd = np.array([(len(self.ind) * i // self.numBatches) for i in range(self.numBatches + 1)], dtype=np.int64)
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return self.numBatches

    def __getitem__(self, index):
        'Generate one batch of data'
        batchIndex1 = self.batchInd[index]
        batchIndex2 = self.batchInd[index + 1]
        batchInd = self.ind[batchIndex1: batchIndex2]
        return self.trainX[batchInd], self.trainY[batchInd]

    def on_epoch_end(self):
        # reshuffle the sample order at the end of every epoch
        if self.shuffle:
            np.random.shuffle(self.ind)
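# Minimal usage sketch for DataGenerator (illustrative only, not executed by this
# script; the toy shapes are assumptions matching the MNIST preprocessing below):
#   gen = DataGenerator(np.zeros((100, 28, 28, 1)), np.zeros((100, 10)), numBatches=4)
#   len(gen)         # -> 4 batches per epoch
#   xb, yb = gen[0]  # -> first batch of 100 // 4 = 25 samples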
# load train and test dataset
def load_dataset():
    # load dataset
    (trainX, trainY), (testX, testY) = mnist.load_data()
    # reshape dataset to have a single channel
    trainX = trainX.reshape((trainX.shape[0], 28, 28, 1))
    testX = testX.reshape((testX.shape[0], 28, 28, 1))
    # one hot encode target values
    trainY = to_categorical(trainY)
    testY = to_categorical(testY)
    # scale pixel values
    trainX, testX = prep_pixels(trainX, testX)
    return trainX, trainY, testX, testY

# scale pixels
def prep_pixels(train, test):
    # convert from integers to floats
    train_norm = train.astype('float32')
    test_norm = test.astype('float32')
    # normalize to range 0-1
    train_norm = train_norm / 255.0
    test_norm = test_norm / 255.0
    # return normalized images
    return train_norm, test_norm
# define cnn model
def define_model():
    model = Sequential()
    model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', input_shape=(28, 28, 1)))
    model.add(MaxPooling2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(100, activation='relu', kernel_initializer='he_uniform'))
    model.add(Dense(10, activation='softmax'))
    # compile model
    opt = SGD(lr=0.01, momentum=0.9)
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
    return model
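# For reference, the layer output shapes produced by define_model() (valid padding):
#   Conv2D 3x3, 32 filters : 28x28x1  -> 26x26x32
#   MaxPooling2D 2x2       : 26x26x32 -> 13x13x32
#   Flatten                : 13x13x32 -> 5408 features
#   Dense(100, relu)       -> Dense(10, softmax) class scores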
# evaluate a model using k-fold cross-validation
def evaluate_model(trainGenerators, validateGenerators):
    scores, histories = list(), list()
    # enumerate splits
    for trainGen, valGen in zip(trainGenerators, validateGenerators):
        # define model
        model = define_model()
        startTime = time.time()
        # fit model
        history = model.fit_generator(trainGen,
            validation_data=valGen,
            epochs=10, verbose=0,
            steps_per_epoch=len(trainGen),
            validation_steps=len(valGen),
            max_queue_size=100,
            use_multiprocessing=True,
            workers=1)
        print("Fit in {} sec".format(time.time() - startTime))
        # evaluate model
        _, acc = model.evaluate(x=valGen, verbose=0)
        print('> %.3f' % (acc * 100.0))
        # store scores
        scores.append(acc)
        histories.append(history)
    return scores, histories
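# Note: fit_generator is deprecated in newer TensorFlow releases; Model.fit accepts a
# keras.utils.Sequence directly, so model.fit(trainGen, validation_data=valGen, ...)
# is the equivalent call. fit_generator is kept here since this gist targets the
# generator code path referenced by the workaround above.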
# plot diagnostic learning curves
def summarize_diagnostics(histories):
    for i in range(len(histories)):
        # plot loss
        pyplot.subplot(2, 1, 1)
        pyplot.title('Cross Entropy Loss')
        pyplot.plot(histories[i].history['loss'], color='blue', label='train')
        pyplot.plot(histories[i].history['val_loss'], color='orange', label='test')
        # plot accuracy
        pyplot.subplot(2, 1, 2)
        pyplot.title('Classification Accuracy')
        pyplot.plot(histories[i].history['accuracy'], color='blue', label='train')
        pyplot.plot(histories[i].history['val_accuracy'], color='orange', label='test')
    pyplot.show()

# summarize model performance
def summarize_performance(scores):
    # print summary
    print('Accuracy: mean=%.3f std=%.3f, n=%d' % (mean(scores) * 100, std(scores) * 100, len(scores)))
    # box and whisker plots of results
    pyplot.boxplot(scores)
    pyplot.show()
# run the test harness for evaluating a model
def run_test_harness():
    trainX, trainY, testX, testY = load_dataset()
    # build one train/validation generator pair per cross-validation fold
    trainGenerators = []
    validateGenerators = []
    kfold = KFold(5)
    for outOfFold, inFold in kfold.split(trainX):
        # outOfFold: training indices for this fold, inFold: held-out validation indices
        foldTrainX = trainX[outOfFold]
        foldTrainY = trainY[outOfFold]
        foldValX = trainX[inFold]
        foldValY = trainY[inFold]
        # training batch size of roughly 32; validation split into 100 batches
        trainGenerators.append(DataGenerator(foldTrainX, foldTrainY, len(foldTrainX) // 32))
        validateGenerators.append(DataGenerator(foldValX, foldValY, 100))
    # evaluate model
    scores, histories = evaluate_model(trainGenerators, validateGenerators)
    # learning curves
    summarize_diagnostics(histories)
    # summarize estimated performance
    summarize_performance(scores)
if __name__ == "__main__":
    # entry point, run the test harness
    run_test_harness()