""" | |
Usage: | |
# From tensorflow/models/ | |
# Create train data: | |
python generate_tfrecord.py --csv_input=images/train_labels.csv --image_dir=images/train --output_path=train.record | |
# Create test data: | |
python generate_tfrecord.py --csv_input=images/test_labels.csv --image_dir=images/test --output_path=test.record | |
""" | |
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import io
from collections import namedtuple
from argparse import ArgumentParser

import pandas as pd
import tensorflow as tf
from PIL import Image

from object_detection.utils import dataset_util

argParser = ArgumentParser()
argParser.add_argument('--classes', nargs='*', help='List of classes, in label-map order', required=True)
argParser.add_argument('--csv_input', dest='csv_input', type=str)
argParser.add_argument('--image_dir', dest='image_dir', type=str)
argParser.add_argument('--output_path', dest='output_path', type=str)
FLAGS = argParser.parse_args()
def split(df, group):
    # Group the CSV rows by filename so each image yields one tf.train.Example.
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
def create_tf_example(group, path):
    with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size

    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []
    for index, row in group.object.iterrows():
        # Box coordinates are stored normalized to [0, 1].
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        # Map the row's class name to its 1-based label ID from --classes
        # (raises ValueError if the name is not in the list).
        classes.append(FLAGS.classes.index(row['class']) + 1)
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main(_):
    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    path = os.path.join(os.getcwd(), FLAGS.image_dir)
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())
    writer.close()
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    tf.app.run()
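
The input CSV is expected to have at least the columns filename, class, xmin, ymin, xmax, ymax (the layout typically produced by an xml_to_csv step). For a quick sanity check after the script runs, one record can be read back and decoded; this is a minimal sketch assuming the same TF 1.x API used above and an output file named train.record:

# Sanity-check sketch: read back the first record from the generated TFRecord.
# Assumes TF 1.x (same as the script above) and that the output file is train.record.
import tensorflow as tf

for record in tf.python_io.tf_record_iterator('train.record'):
    example = tf.train.Example.FromString(record)
    print(example.features.feature['image/filename'].bytes_list.value)
    print(example.features.feature['image/object/class/label'].int64_list.value)
    break

If the printed label IDs do not match the order of the names passed to --classes, the label map used for training will not line up with the records.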