@jjallaire
Last active December 26, 2019 21:05
'''Train MNIST with tfrecords yielded from a TF Dataset

In order to run this example you should first run 'mnist_to_tfrecord.py',
which will download MNIST data and serialize it into 3 tfrecords files
(train.tfrecords, validation.tfrecords, and test.tfrecords).

This example demonstrates the use of TF Datasets wrapped by a generator
function. The example currently only works with a fork of Keras that accepts
`workers=0` as an argument to fit_generator, etc. Passing `workers=0` results
in the generator function being run on the main thread (without this, various
errors ensue because of the way TF handles being called from a background
thread). You can install the fork with support for `workers=0` from here:
https://github.com/jjallaire/keras/tree/feature/main-thread-generator
'''
from __future__ import print_function

from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop
import keras.backend as K

import tensorflow as tf

batch_size = 128
num_classes = 10
epochs = 20
steps_per_epoch = 500

# Return a TF dataset for the specified filename(s)
def mnist_dataset(filenames):

    # Parse a serialized Example proto into a (flattened image, one-hot label)
    # pair, scaling pixel values to [0, 1]
    def decode_example(example_proto):
        features = tf.parse_single_example(
            example_proto,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64)
            }
        )
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image = tf.cast(image, tf.float32) / 255.
        label = tf.one_hot(tf.cast(features['label'], tf.int32), num_classes)
        return [image, label]

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(decode_example)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(10000)
    dataset = dataset.batch(batch_size)
    return dataset

# Keras generator that yields batches from the specified tfrecord filename(s)
def mnist_generator(filenames):
    dataset = mnist_dataset(filenames)
    iterator = dataset.make_one_shot_iterator()
    batch = iterator.get_next()
    while True:
        # batch_get_value evaluates the iterator's tensors in the TF session,
        # so this must run on the thread that owns the session (hence workers=0)
        yield K.batch_get_value(batch)
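
# A quick sanity check (a sketch, not part of the original gist): pull one
# batch from the generator and confirm the shapes Keras will see. Assumes
# mnist_to_tfrecord.py has already written mnist/train.tfrecords.
def check_generator(filenames='mnist/train.tfrecords'):
    images, labels = next(mnist_generator(filenames))
    print('images:', images.shape)  # expected: (128, 784)
    print('labels:', labels.shape)  # expected: (128, 10)
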
model = Sequential()
model.add(Dense(256, activation='relu', input_shape=(784,)))
model.add(Dropout(0.4))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(num_classes, activation='softmax'))
model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])

history = model.fit_generator(
    mnist_generator('mnist/train.tfrecords'),
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    verbose=1,
    validation_data=mnist_generator('mnist/validation.tfrecords'),
    validation_steps=steps_per_epoch,
    workers=0  # run the generator on the main thread
)

score = model.evaluate_generator(
    mnist_generator('mnist/test.tfrecords'),
    steps=steps_per_epoch,
    workers=0  # run the generator on the main thread
)

print('Test loss:', score[0])
print('Test accuracy:', score[1])
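
# ----------------------------------------------------------------------------
# mnist_to_tfrecord.py (referenced in the docstring above): downloads MNIST
# and serializes it into train/validation/test TFRecords files.
# ----------------------------------------------------------------------------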
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts MNIST data to TFRecords file format with Example protos."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import mnist
FLAGS = None
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
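
# For illustration (these lines are not in the original file): the helpers
# above wrap plain Python values in the protobuf containers that
# tf.train.Example expects, e.g.
#   _int64_feature(7)      -> Feature(int64_list=Int64List(value=[7]))
#   _bytes_feature(b'abc') -> Feature(bytes_list=BytesList(value=[b'abc']))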

def convert_to(data_set, name):
    """Converts a dataset to tfrecords."""
    images = data_set.images
    labels = data_set.labels
    num_examples = data_set.num_examples

    if images.shape[0] != num_examples:
        raise ValueError('Images size %d does not match label size %d.' %
                         (images.shape[0], num_examples))
    rows = images.shape[1]
    cols = images.shape[2]
    depth = images.shape[3]

    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing', filename)
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index in range(num_examples):
            image_raw = images[index].tostring()
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(rows),
                'width': _int64_feature(cols),
                'depth': _int64_feature(depth),
                'label': _int64_feature(int(labels[index])),
                'image_raw': _bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())
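
# Optional round-trip check (a sketch, assuming TF 1.x APIs; not part of the
# original converter): read the first serialized record back and print its
# label, confirming that convert_to wrote parseable Example protos.
def _peek_first_record(filename):
    record = next(tf.python_io.tf_record_iterator(filename))
    example = tf.train.Example()
    example.ParseFromString(record)
    print('first label:', example.features.feature['label'].int64_list.value[0])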

def main(unused_argv):
    # Get the data.
    data_sets = mnist.read_data_sets(FLAGS.directory,
                                     dtype=tf.uint8,
                                     reshape=False,
                                     validation_size=FLAGS.validation_size)

    # Convert to Examples and write the result to TFRecords.
    convert_to(data_sets.train, 'train')
    convert_to(data_sets.validation, 'validation')
    convert_to(data_sets.test, 'test')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--directory',
        type=str,
        default='mnist',
        help='Directory to download data files and write the converted result'
    )
    parser.add_argument(
        '--validation_size',
        type=int,
        default=5000,
        help="""\
        Number of examples to separate from the training data for the
        validation set.\
        """
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
aleio1 commented Mar 22, 2019

Hi, is it normal that GPU usage is only about 4% when running this code with tensorflow-gpu?
