Created March 1, 2019 20:42
Example of basic MNIST Keras model with tf.Dataset
# Example of basic MNIST Keras model with tf.Dataset
# More up-to-date version of: https://github.com/keras-team/keras/blob/master/examples/mnist_dataset_api.py
"""
MNIST classification with TensorFlow's Dataset API.

Introduced in TensorFlow 1.3, the Dataset API is now the
standard method for loading data into TensorFlow models.
A Dataset is a sequence of elements, which are themselves
composed of tf.Tensor components. For more details, see:
https://www.tensorflow.org/programmers_guide/datasets

To use this with Keras, we make a dataset out of elements
of the form (input batch, output batch). From there, we
create a one-shot iterator and a graph node corresponding
to its get_next() method. These tensors are then provided
to the network instead of plain numpy arrays.

See also the mnist_tfrecord.py example.
"""
import numpy as np
import os
import tempfile

import keras
from keras import backend as K
from keras import layers
from keras.datasets import mnist
import tensorflow as tf

if K.backend() != 'tensorflow':
    raise RuntimeError('This example can only run with the TensorFlow backend,'
                       ' because it requires the Dataset API, which is not'
                       ' supported on other platforms.')

batch_size = 128
shuffle_size = 1024
epochs = 5
num_classes = 10


def build_model():
    inputs = layers.Input(shape=(28, 28))
    # add a dimension for conv channels: (28, 28) -> (28, 28, 1)
    x = layers.Lambda(K.expand_dims)(inputs)
    x = layers.Conv2D(32, (3, 3), activation='relu', padding='valid')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Flatten()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=keras.optimizers.RMSprop(lr=2e-3, decay=1e-5),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def make_dataset(x, y, shuffle=False):
    def preprocess(image, label):
        """Preprocess raw data to trainable input."""
        x = tf.cast(image, tf.float32) / 255
        y = tf.one_hot(tf.cast(label, tf.uint8), num_classes)
        return x, y

    # NOTE: This stores the provided numpy arrays in the
    # TF graph as constants! It's only useful for small data.
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.map(preprocess)
    dataset = dataset.repeat()
    if shuffle:
        dataset = dataset.shuffle(shuffle_size)
    # Keras does not support tensors with dynamic batch size
    dataset = dataset.batch(batch_size, drop_remainder=True)
    iterator = dataset.make_one_shot_iterator()
    inputs, targets = iterator.get_next()
    return inputs, targets
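# For larger data, one could avoid baking constants into the graph by
# feeding through placeholders and an initializable iterator instead
# (a sketch of the standard TF 1.x pattern, not used in this example):
#
#   x_ph = tf.placeholder(x.dtype, x.shape)
#   y_ph = tf.placeholder(y.dtype, y.shape)
#   dataset = tf.data.Dataset.from_tensor_slices((x_ph, y_ph))
#   iterator = dataset.make_initializable_iterator()
#   K.get_session().run(iterator.initializer,
#                       feed_dict={x_ph: x, y_ph: y})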

# numpy arrays
(x_train, y_train), (x_test, y_test) = mnist.load_data()

model = build_model()
model.summary()

# tensors
inputs_train, targets_train = make_dataset(x_train, y_train, shuffle=True)
inputs_test, targets_test = make_dataset(x_test, y_test, shuffle=False)

steps_per_epoch = int(np.ceil(len(x_train) / float(batch_size)))  # = 469
validation_steps = int(np.ceil(len(x_test) / float(batch_size)))  # = 79

# Since upstream Keras 2.2.0 it's possible to provide tensors for
# training and validation inputs/outputs, while tf.keras directly
# accepts a tf.data.Dataset.
model.fit(x=inputs_train,
          y=targets_train,
          epochs=epochs,
          steps_per_epoch=steps_per_epoch,
          validation_data=(inputs_test, targets_test),
          validation_steps=validation_steps)
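# With tf.keras, the tf.data.Dataset could be passed to fit() directly
# instead of iterator tensors, e.g. (a sketch, assuming make_dataset were
# changed to return the dataset itself):
#
#   model.fit(train_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch)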

loss, acc = model.evaluate(inputs_test, targets_test, steps=validation_steps)
print('\nTest accuracy: {0}'.format(acc))

# The model can then be used either with numpy arrays or other tensors.
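# For example (a sketch following the note above): prediction from a plain
# numpy array, preprocessed the same way as in the dataset pipeline.
x_test_norm = x_test.astype(np.float32) / 255
predictions = model.predict(x_test_norm, batch_size=batch_size)
print('Predicted class of the first test image: {0}'.format(
    np.argmax(predictions[0])))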