Skip to content

Instantly share code, notes, and snippets.

@vihari
Created May 10, 2018 08:35
Show Gist options
  • Save vihari/43df91579ad8cce7c548228f2e3365dd to your computer and use it in GitHub Desktop.
Save vihari/43df91579ad8cce7c548228f2e3365dd to your computer and use it in GitHub Desktop.
Script to export to TFRecords
"""
Exports data as sharded TFRecord files written to save_dir.
train_data, validation_data and test_data are lists of tuples, each containing: (image_data, label, domain id, file_path (if available))
"""
def export_tfrecord(save_dir, train_data, validation_data, test_data):
    """Write the train/validation/test splits to sharded TFRecord files.

    Each non-empty split is divided into ``_NUM_SHARDS`` contiguous shards;
    the path of every shard comes from
    ``_get_dataset_filename(save_dir, split_name, shard_id)``. Progress is
    written to stdout, and a per-split count of labels and uids is printed
    once the split has been written.

    NOTE(review): ``tf``, ``_NUM_SHARDS`` and ``_get_dataset_filename`` are
    not defined in this block and are expected to be provided by the
    enclosing module.

    Args:
        save_dir: Directory handed to ``_get_dataset_filename``.
        train_data: Sequence of ``(image_data, label, uid, file_path)``
            tuples. ``image_data`` must expose ``.tobytes()``; ``file_path``
            may be ``None`` when unavailable. The records are written in a
            random order, but the caller's sequence is left untouched.
        validation_data: Same layout as ``train_data``; written in order.
        test_data: Same layout as ``train_data``; written in order.

    Returns:
        None.
    """
    import math
    import random
    import sys

    # Shuffle a private copy so the caller's list is not reordered in place.
    train_data = list(train_data)
    random.shuffle(train_data)

    named_splits = (
        ("train", train_data),
        ("validation", validation_data),
        ("test", test_data),
    )
    for split_name, data in named_splits:
        # Skip empty splits before doing any work (or creating any file).
        if len(data) == 0:
            continue

        total = len(data)
        num_per_shard = int(math.ceil(total / float(_NUM_SHARDS)))
        cl_dict, uid_dict = {}, {}

        # TFRecordWriter is a plain file writer, so no Graph/Session is needed.
        for shard_id in range(_NUM_SHARDS):
            output_filename = _get_dataset_filename(
                save_dir, split_name, shard_id)
            start_ndx = shard_id * num_per_shard
            end_ndx = min((shard_id + 1) * num_per_shard, total)
            with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                for i in range(start_ndx, end_ndx):
                    sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                        i + 1, total, shard_id))
                    sys.stdout.flush()

                    image_data, label, uid, file_path = data[i]
                    # BytesList only accepts bytes: the format tag is a bytes
                    # literal and file_path is coerced (None -> b'').
                    example = tf.train.Example(features=tf.train.Features(feature={
                        'image': tf.train.Feature(bytes_list=tf.train.BytesList(
                            value=[image_data.tobytes()])),
                        'format': tf.train.Feature(bytes_list=tf.train.BytesList(
                            value=[b'raw'])),
                        'label': tf.train.Feature(int64_list=tf.train.Int64List(
                            value=[label])),
                        'uid': tf.train.Feature(int64_list=tf.train.Int64List(
                            value=[uid])),
                        'file_path': tf.train.Feature(bytes_list=tf.train.BytesList(
                            value=[_as_bytes(file_path)])),
                    }))
                    cl_dict[label] = cl_dict.get(label, 0) + 1
                    uid_dict[uid] = uid_dict.get(uid, 0) + 1
                    tfrecord_writer.write(example.SerializeToString())

        # Per-split summary of how many examples each label / uid received.
        print("\nClass labels: %s" % cl_dict)
        print("UIDs: %s" % uid_dict)
        sys.stdout.write('\n')
        sys.stdout.flush()


def _as_bytes(value):
    """Coerce *value* to bytes for a BytesList feature; None becomes b''."""
    if value is None:
        return b''
    if isinstance(value, bytes):
        return value
    if not hasattr(value, 'encode'):
        # Non-string objects are stringified first.
        value = str(value)
    return value.encode('utf-8')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment