Created
December 6, 2016 20:39
-
-
Save amir-rahnama/408301bc5bc07bc5afa8748513ab9477 to your computer and use it in GitHub Desktop.
Write Your Own Custom Image Dataset for Tensorflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A generic module to read data.""" | |
import numpy | |
import collections | |
from tensorflow.python.framework import dtypes | |
class DataSet(object):
    """In-memory images and labels that are served in mini-batches.

    The first pass over the data is served in the order given. Each time a
    request would run past the end of the data, the epoch counter is bumped,
    images and labels are reshuffled together, and serving restarts from the
    beginning of the reshuffled data.
    """

    def __init__(self,
                 images,
                 labels,
                 fake_data=False,
                 one_hot=False,
                 dtype=dtypes.float64,
                 reshape=True):
        """Store the arrays and reset the epoch bookkeeping.

        Args:
            images: array of images. When `reshape` is true it must be 4-D
                with a last axis of size 1, and it is flattened to
                `[num_examples, rows * columns]`.
            labels: labels, sliced and shuffled in step with `images`.
            fake_data: accepted but not used by this implementation.
            one_hot: accepted but not used by this implementation.
            dtype: accepted but not used; the images are stored as given.
            reshape: whether to flatten every image to a vector.

        Raises:
            AssertionError: if `reshape` is true and the last image axis is
                not of size 1.
        """
        if reshape:
            assert images.shape[3] == 1
            images = images.reshape(images.shape[0],
                                    images.shape[1] * images.shape[2])
        self._images = images
        self._num_examples = images.shape[0]
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        """The stored images, in their current (possibly shuffled) order."""
        return self._images

    @property
    def labels(self):
        """The stored labels, in the same order as `images`."""
        return self._labels

    @property
    def num_examples(self):
        """Number of examples, taken from the first image axis."""
        return self._num_examples

    @property
    def epochs_completed(self):
        """How many times serving has wrapped around the data."""
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set.

        When the request would run past the end of the data, the remaining
        examples of the current pass are skipped, the data is reshuffled and
        the batch is taken from the start of the new order.

        Args:
            batch_size: number of examples to return.
            fake_data: accepted but not used by this implementation.

        Returns:
            A tuple `(images, labels)` holding the same slice of both arrays.

        Raises:
            AssertionError: if `batch_size` exceeds the number of examples.
        """
        # Reject an oversized request before touching any state. A request
        # larger than the data always crosses the epoch boundary, so this
        # fires in exactly the cases the post-shuffle check used to, but no
        # longer leaves the counters advanced and the data reshuffled.
        assert batch_size <= self._num_examples
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle images and labels with one shared permutation so the
            # pairs stay aligned.
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir, fake_data=False, one_hot=False,
                   dtype=dtypes.float64, reshape=True,
                   validation_size=5000):
    """Load the image arrays from disk and split them into three datasets.

    Images are read from ``./npy/grey.npy`` and given a trailing channel
    axis of size 1. ``./npy/label.npy`` is read only for its length: every
    example is labelled with its own index, one-hot encoded over as many
    classes as there are labels. The first 3000 examples become the training
    set, the next 1000 the validation set and the 1000 after that the test
    set.

    `train_dir`, `fake_data`, `one_hot` and `validation_size` are accepted
    but not used; `dtype` and `reshape` are forwarded to `DataSet`.

    Returns:
        A ``Datasets`` named tuple with ``train``, ``validation`` and
        ``test`` fields, each a `DataSet`.
    """
    train_count, validation_count, test_count = 3000, 1000, 1000

    images = numpy.load('./npy/grey.npy')
    count, rows, cols = images.shape[0], images.shape[1], images.shape[2]
    images = images.reshape(count, rows, cols, 1)

    label_source = numpy.load('./npy/label.npy')
    labels = numpy.asarray(range(0, len(label_source)))
    labels = dense_to_one_hot(labels, len(labels))

    def _take(first, last):
        # Pick rows [first, last) from both arrays with the same indices.
        picked = range(first, last)
        return images[picked], labels[picked]

    validation_start = train_count
    test_start = validation_start + validation_count
    train_x, train_y = _take(0, validation_start)
    validation_x, validation_y = _take(validation_start, test_start)
    test_x, test_y = _take(test_start, test_start + test_count)

    options = dict(dtype=dtype, reshape=reshape)
    train = DataSet(train_x, train_y, **options)
    validation = DataSet(validation_x, validation_y, **options)
    test = DataSet(test_x, test_y, **options)

    bundle = collections.namedtuple('Datasets',
                                    ['train', 'validation', 'test'])
    return bundle(train=train, validation=validation, test=test)
def dense_to_one_hot(labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors.

    Args:
        labels_dense: NumPy array of integer class ids, one per example.
        num_classes: width of each one-hot row; every class id must lie in
            ``[0, num_classes)``.

    Returns:
        A float array of shape ``(labels_dense.shape[0], num_classes)`` that
        is zero everywhere except for a single 1 per row, at the column given
        by that row's label.

    Raises:
        IndexError: if any label is negative or not smaller than
            `num_classes`.
    """
    num_labels = labels_dense.shape[0]
    flat_labels = labels_dense.ravel()
    # The flat-index arithmetic below only lands in the intended row when the
    # label is a valid column; an out-of-range label would silently set a
    # cell belonging to a neighbouring row, so reject it up front.
    if flat_labels.size and (flat_labels.min() < 0
                             or flat_labels.max() >= num_classes):
        raise IndexError(
            'labels must be in the range [0, %d)' % num_classes)
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + flat_labels] = 1
    return labels_one_hot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
from __future__ import print_function | |
import argparse | |
import sys | |
import dataset | |
import tensorflow as tf | |
# Parsed command-line flags; assigned in the __main__ block before main() runs.
FLAGS = None
def main(_):
    """Train a single-layer softmax classifier and print its test accuracy.

    Loads the datasets through `dataset.read_data_sets`, runs 1000 gradient
    descent steps on batches of 100 training examples (printing the shape of
    every batch), then evaluates accuracy on the whole test set.

    Args:
        _: the argument list handed over by `tf.app.run`; not used.
    """
    mnist = dataset.read_data_sets(FLAGS.data_dir, one_hot=True)

    # NOTE(review): the sizes below assume every image flattens to 10000
    # values and that the labels are one-hot over 5000 classes -- confirm
    # against the arrays that read_data_sets loads.
    x = tf.placeholder(tf.float32, [None, 10000])
    w = tf.Variable(tf.zeros([10000, 5000]))
    b = tf.Variable(tf.zeros([5000]))
    y = tf.matmul(x, w) + b

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 5000])

    # The raw formulation of cross-entropy,
    #
    #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
    #                                 reduction_indices=[1]))
    #
    # can be numerically unstable.
    #
    # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
    # outputs of 'y', and then average across the batch. The arguments are
    # passed by keyword because TensorFlow 1.x rejects positional use of
    # this function.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    # Train
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        print(batch_xs.shape)
        print(batch_ys.shape)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Test trained model
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                        y_: mnist.test.labels}))
if __name__ == '__main__':
    # Parse the flags this script knows about and hand the remaining
    # arguments on to tf.app.run.
    parser = argparse.ArgumentParser()
    # The default path is written as one literal: a backslash continuation
    # inside the quotes would make the whitespace around the line break part
    # of the path.
    parser.add_argument('--data_dir', type=str,
                        default='/tmp/tensorflow/mnist/input_data',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
What if I wanted to load the images on the fly (i.e., during training) instead of loading all the data beforehand?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@tandcredosouza
Convert them with np.array(...) in a for loop.