convert_data.py
# coding=utf-8
import math
import os
import random
import tensorflow as tf
import dataset_utils
import argparse
# Global constants.
# TRAIN_FRACTION: fraction of the data used for training.
TRAIN_FRACTION = 0.6
# TEST_FRACTION: fraction of the data used for testing.
TEST_FRACTION = 0.2
# Random seed.
_RANDOM_SEED = 0
# Number of TFRecord shards per split.
_NUM_SHARDS = 2
# DATASET_SUBDIR: name of the sub-directory that holds the converted dataset.
DATASET_SUBDIR = "dataset"
# Task names.
TASKS = ["main"]
# Number of examples written per split, filled in by _output_dataset.
split_to_count = {}
# Recursively collect all picture file names under dataset_dir, each paired with class_label.
def _get_all_filenames(dataset_dir, class_label):
    filenames_and_labels = []
    for filename in tf.gfile.ListDirectory(dataset_dir):
        if filename == DATASET_SUBDIR:
            continue
        path = os.path.join(dataset_dir, filename)
        if tf.gfile.IsDirectory(path):
            filenames_and_labels.extend(_get_all_filenames(path, class_label))
        elif dataset_utils.is_picture_file(filename):
            filenames_and_labels.append((path, class_label))
    return filenames_and_labels
# Build the output file name for one TFRecord shard.
def _get_dataset_filename(dataset_dir, split_name, task_name, shard_id):
    output_filename = '%s_%s_%05d-of-%05d.tfrecord' % (
        task_name, split_name, shard_id, _NUM_SHARDS)
    return os.path.join(dataset_dir, output_filename)
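# With the defaults above this yields file names such as:
#   main_train_00000-of-00002.tfrecord
#   main_train_00001-of-00002.tfrecord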
# Convert one split: read the image files and write them out as TFRecord shards.
def _convert_dataset(split_name, task_name, filenames_and_labels, dataset_dir):
    num_per_shard = int(math.ceil(len(filenames_and_labels) / float(_NUM_SHARDS)))
    print("For task %s, split %s, read data and convert into %d shards with %d examples per shard" % (
        task_name, split_name, _NUM_SHARDS, num_per_shard))
    with tf.Graph().as_default():
        image_reader = dataset_utils.ImageReader()
        with tf.Session() as sess:
            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(dataset_dir, split_name, task_name, shard_id)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    print("Start reading images in shard %d" % shard_id)
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames_and_labels))
                    for i in range(start_ndx, end_ndx):
                        # Read the raw image bytes ('rb' so the data survives on Python 3).
                        image_data = tf.gfile.FastGFile(filenames_and_labels[i][0], 'rb').read()
                        height, width = image_reader.read_image_dims(sess, image_data)
                        class_id = filenames_and_labels[i][1]
                        example = dataset_utils.image_to_tfexample(
                            image_data, 'jpg', height, width, class_id, filenames_and_labels[i][0])
                        tfrecord_writer.write(example.SerializeToString())
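# A minimal sketch of reading the shards back, assuming dataset_utils follows the
# TF-slim convention for feature keys ('image/encoded', 'image/class/label', ...);
# the exact keys depend on what dataset_utils.image_to_tfexample actually writes:
#
#   for record in tf.python_io.tf_record_iterator('data/dataset/dataset/main_train_00000-of-00002.tfrecord'):
#       example = tf.train.Example()
#       example.ParseFromString(record)
#       label = example.features.feature['image/class/label'].int64_list.value[0]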
# Shuffle the examples, split them into train/test/validate, and write each split.
def _output_dataset(filenames_and_labels, labels_to_class_names, task_name, dataset_dir):
    random.shuffle(filenames_and_labels)
    all_examples = len(filenames_and_labels)
    train_examples = int(all_examples * TRAIN_FRACTION)
    test_examples = int(all_examples * TEST_FRACTION)
    validate_examples = all_examples - train_examples - test_examples
    print("Generated %d training examples, %d testing examples and %d validation examples for the main task." % (
        train_examples, test_examples, validate_examples))
    _convert_dataset('test', task_name, filenames_and_labels[:test_examples], dataset_dir)
    _convert_dataset('train', task_name, filenames_and_labels[test_examples:train_examples + test_examples], dataset_dir)
    _convert_dataset('validate', task_name, filenames_and_labels[train_examples + test_examples:], dataset_dir)
    dataset_utils.write_label_file(
        labels_to_class_names, dataset_dir, "%s_labels_to_class.txt" % task_name)
    split_to_count["%s_train" % task_name] = train_examples
    split_to_count["%s_test" % task_name] = test_examples
    split_to_count["%s_validate" % task_name] = validate_examples
# Main conversion entry: every sub-directory of dataset_dir is one class, and
# every picture file under it (recursively) becomes one example of that class.
# The converted dataset is stored in output_dir.
def _process_main(dataset_dir, output_dir):
    labels_to_class_names = {}
    filenames_and_labels = []
    cur_id = 0
    for filename in tf.gfile.ListDirectory(dataset_dir):
        if filename == DATASET_SUBDIR:
            continue
        path = os.path.join(dataset_dir, filename)
        if tf.gfile.IsDirectory(path):
            filenames_and_labels.extend(_get_all_filenames(path, cur_id))
            labels_to_class_names[cur_id] = filename
            cur_id += 1
    _output_dataset(filenames_and_labels, labels_to_class_names, "main", output_dir)
# Convert a per-label sub-task: cur_dict maps first-level directory names to the
# label tokens to keep; second-level directory names are expected to look like
# "<prefix>-<label>", and each kept second-level directory becomes one class.
def _process_labels(dataset_dir, output_dir, task_name, cur_dict):
    labels_to_class_names = {}
    filenames_and_labels = []
    cur_id = 0
    for fir_dir_name in tf.gfile.ListDirectory(dataset_dir):
        if fir_dir_name == DATASET_SUBDIR:
            continue
        if fir_dir_name not in cur_dict:
            continue
        path = os.path.join(dataset_dir, fir_dir_name)
        if not tf.gfile.IsDirectory(path):
            continue
        for sec_dir_name in tf.gfile.ListDirectory(path):
            parts = sec_dir_name.split("-")
            # Skip directories without a "-<label>" suffix or with an unwanted label.
            if len(parts) < 2 or parts[1] not in cur_dict[fir_dir_name]:
                continue
            subdir = os.path.join(path, sec_dir_name)
            if not tf.gfile.IsDirectory(subdir):
                continue
            filenames_and_labels.extend(_get_all_filenames(subdir, cur_id))
            labels_to_class_names[cur_id] = sec_dir_name
            cur_id += 1
    _output_dataset(filenames_and_labels, labels_to_class_names, task_name, output_dir)
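# Note: run() below only invokes _process_main. A hypothetical call for a layout
# like dataset_dir/animals/img-cat/*.jpg might look like:
#
#   _process_labels(dataset_dir, output_dir, "animals", {"animals": {"cat", "dog"}})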
# Run the conversion: load the custom file system plugin, prepare the output
# directory, and convert the dataset.
def run(dataset_dir, output_dir):
    tf.load_file_system_library("/root/caicloud/tensorflow/bazel-bin/tensorflow/core/platform/vulture/vulture_file_system.so")
    output_dir = os.path.join(output_dir, DATASET_SUBDIR)
    if not tf.gfile.Exists(dataset_dir):
        raise ValueError("Dataset does not exist.")
    if tf.gfile.Exists(output_dir):
        tf.gfile.DeleteRecursively(output_dir)
    tf.gfile.MakeDirs(output_dir)
    random.seed(_RANDOM_SEED)
    _process_main(dataset_dir, output_dir)
    dataset_utils.write_count_file(split_to_count, output_dir)
    print('\nFinished converting the dataset!')
# Parse command line arguments.
def parse_args():
    parser = argparse.ArgumentParser(description="Convert data to TFRecord format")
    parser.add_argument('--dataset_name', dest='dataset_name', help='Dataset Name', default='data/demo_single_label')
    parser.add_argument('--output_name', dest='output_name', help='Output Name', default='data/dataset')
    parser.add_argument('--train_fraction', dest='train_fraction', help='Train Fraction', default='0.6')
    parser.add_argument('--test_fraction', dest='test_fraction', help='Test Fraction', default='0.2')
    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()
    TRAIN_FRACTION = float(args.train_fraction)
    TEST_FRACTION = float(args.test_fraction)
    run(args.dataset_name, args.output_name)
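# Example usage, assuming a single-label layout with one directory per class
# (e.g. data/demo_single_label/cat/*.jpg, data/demo_single_label/dog/*.jpg):
#
#   python convert_data.py \
#       --dataset_name data/demo_single_label \
#       --output_name data/dataset \
#       --train_fraction 0.6 \
#       --test_fraction 0.2
#
# The shards and the label/count files are written to data/dataset/dataset/.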