r"""Convert an ImageNet like dataset into tfRecord files, provide a method get_dataset to read the created files. | |
It has similar functions as ImageFolder in Pytorch. | |
Modified from | |
https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/download_and_convert_flowers.py | |
https://github.com/tensorflow/models/blob/master/research/slim/datasets/flowers.py | |
""" | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys
import argparse

from PIL import Image
import tensorflow as tf

# Point this at your local checkout of 'tensorflow/models/research/slim'.
sys.path.append(r"/Users/hao/data/slim_models/models/research/slim")
from datasets import dataset_utils

from tensorflow.contrib import slim

# Seed for repeatability.
_RANDOM_SEED = 0
def get_dataset(dataset_dir, dataset_type):
    """Create a dataset from TFRecord files.

    dataset_type specifies train, test or validation data. For example,
    if dataset_type is 'train', files matching the pattern
    "train_*.tfrecord" in dataset_dir are treated as data sources for
    this dataset.

    Example usage:
        import matplotlib.pyplot as plt

        with tf.Graph().as_default():
            dataset = get_dataset(dataset_dir, dataset_type)
            data_provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset, common_queue_capacity=32, common_queue_min=1)
            image, label = data_provider.get(['image', 'label'])
            with tf.Session() as sess:
                with slim.queues.QueueRunners(sess):
                    for i in range(4):
                        np_image, np_label = sess.run([image, label])
                        height, width, _ = np_image.shape
                        name = dataset.labels_to_names[np_label]

                        plt.figure()
                        plt.imshow(np_image)
                        plt.title('%s, %d x %d' % (name, height, width))
                        plt.axis('off')
                        plt.show()
    """
    file_pattern = os.path.join(dataset_dir, '{}_*.tfrecord'.format(dataset_type))

    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    labels_to_names = None
    if dataset_utils.has_labels(dataset_dir):
        labels_to_names = dataset_utils.read_label_file(dataset_dir)

    # The sample count is written by convert() into a '<dataset_type>.num' file.
    with open(os.path.join(dataset_dir, dataset_type + '.num')) as f:
        num_samples = int(f.read().strip())

    return slim.dataset.Dataset(
        data_sources=file_pattern,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=num_samples,
        items_to_descriptions=None,
        # Guard against a missing label file.
        num_classes=len(labels_to_names) if labels_to_names else 0,
        labels_to_names=labels_to_names)
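# A minimal sketch of wiring get_dataset into a batched input pipeline,
# assuming the TFRecord and side files have already been created in
# '/path/to/tfrecords' (paths and batch settings are illustrative only):
#
#   dataset = get_dataset('/path/to/tfrecords', 'train')
#   provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
#   image, label = provider.get(['image', 'label'])
#   image = tf.image.resize_images(image, [224, 224])  # gives a static shape
#   images, labels = tf.train.batch(
#       [image, label], batch_size=32, num_threads=4, capacity=128)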
def get_image_size(image_file):
    """Return the image size; note PIL's Image.size is (width, height)."""
    img = Image.open(image_file)
    return img.size


def get_image_type(image_file):
    """Return the image's file extension (e.g. b'jpg') as ASCII bytes."""
    file_name = os.path.basename(image_file).lower()
    # Use the last dot so 'my.photo.jpg' yields 'jpg', not 'photo.jpg'.
    pos = file_name.rfind('.')
    return file_name[pos + 1:].encode(encoding="ascii")
def _get_filenames_and_classes(image_dir_root):
    """Returns a list of image file paths and the inferred class names.

    Args:
        image_dir_root: A directory containing a set of subdirectories
            representing class names. Each subdirectory should contain
            PNG or JPG encoded images.

    Returns:
        A list of image file paths, rooted at `image_dir_root`, and the
        sorted list of subdirectory names, representing class names.
    """
    directories = []
    class_names = []
    for filename in os.listdir(image_dir_root):
        path = os.path.join(image_dir_root, filename)
        if os.path.isdir(path):
            directories.append(path)
            class_names.append(filename)

    photo_filenames = []
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory, filename)
            photo_filenames.append(path)

    return photo_filenames, sorted(class_names)
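# For the hypothetical layout sketched at the top of this file,
# _get_filenames_and_classes('image_dir/train') would return
# (['image_dir/train/cat/001.jpg', 'image_dir/train/dog/001.jpg', ...],
#  ['cat', 'dog']).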
def _get_dataset_filename(dataset_dir, split_name, shard_id, shard_num):
    output_filename = '%s_%05d-of-%05d.tfrecord' % (
        split_name, shard_id, shard_num)
    return os.path.join(dataset_dir, output_filename)
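# For example, _get_dataset_filename('out', 'train', 0, 2) yields
# 'out/train_00000-of-00002.tfrecord', which matches the 'train_*.tfrecord'
# glob that get_dataset uses to locate its data sources.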
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir, shard_num):
    """Converts the given filenames to a TFRecord dataset.

    Args:
        split_name: The name of the split, e.g. 'train' or 'validation'.
        filenames: A list of absolute paths to png or jpg images.
        class_names_to_ids: A dictionary from class names (strings) to ids
            (integers).
        dataset_dir: The directory where the converted datasets are stored.
        shard_num: The number of shards to split the dataset into.
    """
    num_per_shard = int(math.ceil(len(filenames) / float(shard_num)))

    for shard_id in range(shard_num):
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id, shard_num)

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            start_ndx = shard_id * num_per_shard
            end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
            for i in range(start_ndx, end_ndx):
                sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                    i + 1, len(filenames), shard_id))
                sys.stdout.flush()

                # Read the raw encoded bytes; no decoding is needed since the
                # encoded image is stored directly in the TFExample.
                image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
                # PIL reports size as (width, height).
                width, height = get_image_size(filenames[i])
                image_type = get_image_type(filenames[i])

                # The class name is the name of the parent directory.
                class_name = os.path.basename(os.path.dirname(filenames[i]))
                class_id = class_names_to_ids[class_name]

                example = dataset_utils.image_to_tfexample(
                    image_data, image_type, height, width, class_id)
                tfrecord_writer.write(example.SerializeToString())

    sys.stdout.write('\n')
    sys.stdout.flush()
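# Note: dataset_utils.image_to_tfexample comes from the slim repository; as I
# recall its source, it packs the record under the feature keys
# 'image/encoded', 'image/format', 'image/class/label', 'image/height' and
# 'image/width', the first three of which get_dataset decodes above.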
def convert(image_dir, sub_dir, dataset_dir, shard_num, class_names_to_ids):
    """Converts the images under one split directory into TFRecord shards.

    Args:
        image_dir: The root directory of the image folders.
        sub_dir: The subdirectory of image_dir holding one split (e.g. 'train').
        dataset_dir: The directory where the converted dataset is stored.
        shard_num: The number of shards.
        class_names_to_ids: A dictionary from class names to ids, or None to
            infer the mapping from the subdirectory names (and write a
            label file).

    Returns:
        The class_names_to_ids mapping, so later splits can reuse it.
    """
    filenames, class_names = _get_filenames_and_classes(os.path.join(image_dir, sub_dir))

    save_labels = False
    if not class_names_to_ids:
        class_names_to_ids = dict(zip(class_names, range(len(class_names))))
        save_labels = True

    # Shuffle with a fixed seed for a reproducible ordering.
    random.seed(_RANDOM_SEED)
    random.shuffle(filenames)

    _convert_dataset(sub_dir, filenames, class_names_to_ids, dataset_dir, shard_num)

    if save_labels:
        labels_to_class_names = dict(zip(range(len(class_names)), class_names))
        dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

    # Record the sample count for get_dataset to read back.
    with open(os.path.join(dataset_dir, sub_dir + ".num"), 'w') as f:
        f.write(str(len(filenames)))

    return class_names_to_ids
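# The mapping returned by convert() is threaded through the loop below: the
# first split processed (typically 'train') fixes the class-name-to-id
# assignment, and every later split reuses it so ids stay consistent.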
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir',
                        type=str,
                        required=True,
                        help='Directory for images.')
    parser.add_argument('--dataset_dir',
                        type=str,
                        required=True,
                        help='Directory for storing the converted TFRecord files and the label file.')
    parser.add_argument('--sub_dirs',
                        nargs='+',
                        required=True,
                        help='The subdirectories of image_dir, used for specifying train, val and test data.')
    parser.add_argument('--shard_num',
                        type=int,
                        default=1,
                        help='The number of shards.')
    args, unparsed = parser.parse_known_args()
    if unparsed:
        parser.print_help()
        sys.exit(1)

    dataset_dir = args.dataset_dir
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    class_names_to_ids = None
    for sub_dir in args.sub_dirs:
        class_names_to_ids = convert(args.image_dir, sub_dir, dataset_dir,
                                     args.shard_num, class_names_to_ids)
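# Example invocation (the script name and paths are hypothetical):
#
#   python convert_to_tfrecord.py \
#       --image_dir /path/to/images \
#       --dataset_dir /path/to/tfrecords \
#       --sub_dirs train val \
#       --shard_num 4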