Create TFRecords from a PASCAL VOC-like dataset structure
# This is an example of using the TensorFlow Object Detection API script
# https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/create_pascal_tf_record.py
# The dataset should be laid out like a PASCAL VOC dataset:
#   +Dataset
#     +Annotations
#     +JPEGImages
# Usage:
#   python create_tfrecords_from_xml.py --image_dir=dataset/JPEGImages \
#                                       --annotations_dir=dataset/Annotations \
#                                       --label_map_path=object-detection.pbtxt \
#                                       --output_path=data.record
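#
# The file passed via --label_map_path is a label map text protobuf, as used by
# label_map_util.get_label_map_dict. A minimal sketch (the class names below are
# placeholders, not part of this gist) looks like:
#   item {
#     id: 1
#     name: 'dog'
#   }
#   item {
#     id: 2
#     name: 'cat'
#   }
# Ids must start at 1; id 0 is reserved for the background class.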
import hashlib
import io
import logging
import os

from lxml import etree
import PIL.Image
import tensorflow as tf

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

flags = tf.app.flags
flags.DEFINE_string('image_dir', '', 'Path to image directory.')
flags.DEFINE_string('annotations_dir', '', 'Path to annotations directory.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord.')
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
                    'Path to label map proto.')
FLAGS = flags.FLAGS
def dict_to_tf_example(data, image_dir, label_map_dict):
  """Converts an XML-derived dict to a tf.Example proto.

  Notice that this function normalizes the bounding box coordinates
  provided by the raw data.

  Args:
    data: dict holding XML fields for a single image (obtained by running
      dataset_util.recursive_parse_xml_to_dict); see the sample annotation
      layout sketched after this function.
    image_dir: Path to image directory.
    label_map_dict: A map from string label names to integer ids.

  Returns:
    example: The converted tf.Example.
  """
  # Read the raw JPEG bytes and verify that the image really is a JPEG.
  full_path = os.path.join(image_dir, data['filename'])
  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
  encoded_jpg_io = io.BytesIO(encoded_jpg)
  image = PIL.Image.open(encoded_jpg_io)
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')
  key = hashlib.sha256(encoded_jpg).hexdigest()

  width = int(data['size']['width'])
  height = int(data['size']['height'])

  xmin = []
  ymin = []
  xmax = []
  ymax = []
  classes = []
  classes_text = []
  # Normalize box coordinates to [0, 1]. Annotations without an <object>
  # entry raise KeyError and are written with empty box and class lists.
  try:
    for obj in data['object']:
      xmin.append(float(obj['bndbox']['xmin']) / width)
      ymin.append(float(obj['bndbox']['ymin']) / height)
      xmax.append(float(obj['bndbox']['xmax']) / width)
      ymax.append(float(obj['bndbox']['ymax']) / height)
      classes_text.append(obj['name'].encode('utf8'))
      classes.append(label_map_dict[obj['name']])
  except KeyError:
    print(data['filename'] + ' without objects!')

  difficult_obj = [0] * len(classes)
  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
  }))
  return example
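

# For reference, each annotation file is expected to follow the PASCAL VOC XML
# layout parsed above (file name, class name, and coordinates below are
# illustrative only, not part of this gist):
#   <annotation>
#     <filename>image_001.jpg</filename>
#     <size><width>640</width><height>480</height><depth>3</depth></size>
#     <object>
#       <name>dog</name>
#       <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
#     </object>
#   </annotation>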


def main(_):
  writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

  image_dir = FLAGS.image_dir
  annotations_dir = FLAGS.annotations_dir
  logging.info('Reading from dataset: %s', annotations_dir)

  # Every .xml file in the annotations directory becomes one tf.Example.
  examples_list = os.listdir(annotations_dir)
  for idx, example in enumerate(examples_list):
    if example.endswith('.xml'):
      if idx % 50 == 0:
        print('On image %d of %d' % (idx, len(examples_list)))
      path = os.path.join(annotations_dir, example)
      with tf.gfile.GFile(path, 'r') as fid:
        xml_str = fid.read()
      xml = etree.fromstring(xml_str)
      data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
      tf_example = dict_to_tf_example(data, image_dir, label_map_dict)
      writer.write(tf_example.SerializeToString())

  writer.close()


if __name__ == '__main__':
  tf.app.run()
# One-time setup for the TensorFlow Object Detection API
# (run from tensorflow/models/research/):
#   protoc object_detection/protos/*.proto --python_out=.
#   export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
#   python object_detection/builders/model_builder_test.py
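#
# Optional sanity check, a rough sketch rather than part of this script: count
# the records that were written, using the same TF 1.x API family as above.
#   import tensorflow as tf
#   num_records = sum(1 for _ in tf.python_io.tf_record_iterator('data.record'))
#   print('Wrote %d examples' % num_records)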