'''Train MNIST with tfrecords yielded from a TF Dataset

To run this example, first run 'mnist_to_tfrecord.py' (included below), which
downloads the MNIST data and serializes it into 3 tfrecords files
(train.tfrecords, validation.tfrecords, and test.tfrecords).

This example demonstrates the use of TF Datasets wrapped by a generator
function. It currently only works with a fork of Keras that accepts
`workers=0` as an argument to fit_generator, etc. Passing `workers=0` runs
the generator function on the main thread; without this, various errors
ensue because of the way TF handles being called from a background thread.
You can install the fork with support for `workers=0` from here:
https://github.com/jjallaire/keras/tree/feature/main-thread-generator

(A sketch of a newer alternative that passes the tf.data Dataset directly to
model.fit, avoiding this workaround, follows this file.)
'''
from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop
import keras.backend as K

import tensorflow as tf

batch_size = 128
num_classes = 10
epochs = 20
steps_per_epoch = 500

# Return a TF dataset for specified filename(s)
def mnist_dataset(filenames):
    def decode_example(example_proto):
        features = tf.parse_single_example(
            example_proto,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64)
            }
        )
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image = tf.cast(image, tf.float32) / 255.
        label = tf.one_hot(tf.cast(features['label'], tf.int32), num_classes)
        return [image, label]
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(decode_example)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(10000)
    dataset = dataset.batch(batch_size)
    return dataset
# Keras generator that yields batches from the specified tfrecord filename(s)
def mnist_generator(filenames):
    dataset = mnist_dataset(filenames)
    iterator = dataset.make_one_shot_iterator()
    batch = iterator.get_next()
    while True:
        # Evaluate the next-element tensors to get a batch of numpy arrays
        yield K.batch_get_value(batch)
model = Sequential()
model.add(Dense(256, activation='relu', input_shape=(784,)))
model.add(Dropout(0.4))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(num_classes, activation='softmax'))

model.summary()

model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])

history = model.fit_generator(
    mnist_generator('mnist/train.tfrecords'),
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    verbose=1,
    validation_data=mnist_generator('mnist/validation.tfrecords'),
    validation_steps=steps_per_epoch,
    workers=0  # runs generator on the main thread
)

score = model.evaluate_generator(
    mnist_generator('mnist/test.tfrecords'),
    steps=steps_per_epoch,
    workers=0  # runs generator on the main thread
)

print('Test loss:', score[0])
print('Test accuracy:', score[1])
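As referenced in the docstring above, the `workers=0` fork is only needed because stock fit_generator runs generators on a background thread. In newer TensorFlow releases (1.9+ and 2.x), tf.keras can consume a tf.data.Dataset directly in model.fit, which removes the need for the generator wrapper entirely. The following is a minimal sketch of that alternative, not part of the original gist; it assumes TF 2.x API names (tf.io.parse_single_example, tf.io.decode_raw) and the same tfrecords files produced by mnist_to_tfrecord.py below.

# Sketch: feeding the TFRecord dataset straight to tf.keras (assumes TF 2.x)
import tensorflow as tf

batch_size = 128
num_classes = 10

def mnist_dataset(filenames):
    def decode_example(example_proto):
        features = tf.io.parse_single_example(
            example_proto,
            features={
                'image_raw': tf.io.FixedLenFeature([], tf.string),
                'label': tf.io.FixedLenFeature([], tf.int64)
            }
        )
        image = tf.cast(tf.io.decode_raw(features['image_raw'], tf.uint8), tf.float32) / 255.
        image = tf.reshape(image, [784])  # flatten 28x28x1 into a 784-vector
        label = tf.one_hot(tf.cast(features['label'], tf.int32), num_classes)
        return image, label  # (features, labels) tuple, as expected by model.fit

    return (tf.data.TFRecordDataset(filenames)
            .map(decode_example)
            .shuffle(10000)
            .repeat()
            .batch(batch_size))

model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

# The dataset repeats, so steps_per_epoch / validation_steps define epoch length
model.fit(mnist_dataset('mnist/train.tfrecords'),
          steps_per_epoch=500,
          epochs=20,
          validation_data=mnist_dataset('mnist/validation.tfrecords'),
          validation_steps=500)

Because the dataset is consumed inside the TensorFlow graph rather than round-tripped through a Python generator, no workers argument is involved at all.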
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts MNIST data to TFRecords file format with Example protos.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import os | |
import sys | |
import tensorflow as tf | |
from tensorflow.contrib.learn.python.learn.datasets import mnist | |
FLAGS = None | |
def _int64_feature(value): | |
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) | |
def _bytes_feature(value): | |
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) | |
def convert_to(data_set, name): | |
"""Converts a dataset to tfrecords.""" | |
images = data_set.images | |
labels = data_set.labels | |
num_examples = data_set.num_examples | |
if images.shape[0] != num_examples: | |
raise ValueError('Images size %d does not match label size %d.' % | |
(images.shape[0], num_examples)) | |
rows = images.shape[1] | |
cols = images.shape[2] | |
depth = images.shape[3] | |
filename = os.path.join(FLAGS.directory, name + '.tfrecords') | |
print('Writing', filename) | |
with tf.python_io.TFRecordWriter(filename) as writer: | |
for index in range(num_examples): | |
image_raw = images[index].tostring() | |
example = tf.train.Example(features=tf.train.Features(feature={ | |
'height': _int64_feature(rows), | |
'width': _int64_feature(cols), | |
'depth': _int64_feature(depth), | |
'label': _int64_feature(int(labels[index])), | |
'image_raw': _bytes_feature(image_raw)})) | |
writer.write(example.SerializeToString()) | |
def main(unused_argv): | |
# Get the data. | |
data_sets = mnist.read_data_sets(FLAGS.directory, | |
dtype=tf.uint8, | |
reshape=False, | |
validation_size=FLAGS.validation_size) | |
# Convert to Examples and write the result to TFRecords. | |
convert_to(data_sets.train, 'train') | |
convert_to(data_sets.validation, 'validation') | |
convert_to(data_sets.test, 'test') | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
'--directory', | |
type=str, | |
default='mnist', | |
help='Directory to download data files and write the converted result' | |
) | |
parser.add_argument( | |
'--validation_size', | |
type=int, | |
default=5000, | |
help="""\ | |
Number of examples to separate from the training data for the validation | |
set.\ | |
""" | |
) | |
FLAGS, unparsed = parser.parse_known_args() | |
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) | |
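For reference, the converter above stores 'height', 'width', and 'depth' features in each Example, but the training script only parses 'image_raw' and 'label' and relies on MNIST's fixed 28x28x1 layout flattening to 784 values. A decoder can instead recover the image shape from the record itself. Below is a minimal sketch of such a decoder; decode_full_example is a hypothetical helper written against the same TF 1.x API used in the scripts above, not part of the original gist.

# Sketch: parsing the full Example schema written by convert_to() (TF 1.x API)
import tensorflow as tf

def decode_full_example(example_proto):
    features = tf.parse_single_example(
        example_proto,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string)
        }
    )
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    # Reshape using the dimensions stored alongside the raw pixels (28x28x1 for MNIST)
    shape = tf.stack([features['height'], features['width'], features['depth']])
    image = tf.reshape(image, tf.cast(shape, tf.int32))
    return image, features['label']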
Hi, is it normal that when running this code with tensorflow-gpu, GPU usage is only about 4%?