Created
August 29, 2022 16:59
-
-
Save arm2arm/fe8fa9b1b9f6f716f914c0d62922bfa9 to your computer and use it in GitHub Desktop.
VOC to rec
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Sample TensorFlow XML-to-TFRecord converter

usage: generate_tfrecord.py [-h] [-x XML_DIR] [-l LABELS_PATH] [-o OUTPUT_PATH] [-i IMAGE_DIR] [-c CSV_PATH]

optional arguments:
  -h, --help            show this help message and exit
  -x XML_DIR, --xml_dir XML_DIR
                        Path to the folder where the input .xml files are stored.
  -l LABELS_PATH, --labels_path LABELS_PATH
                        Path to the labels (.pbtxt) file.
  -o OUTPUT_PATH, --output_path OUTPUT_PATH
                        Path of output TFRecord (.record) file.
  -i IMAGE_DIR, --image_dir IMAGE_DIR
                        Path to the folder where the input image files are stored. Defaults to the same directory as XML_DIR.
  -c CSV_PATH, --csv_path CSV_PATH
                        Path of output .csv file. If none provided, then no file will be written.
"""
import os | |
import glob | |
import pandas as pd | |
import io | |
import xml.etree.ElementTree as ET | |
import argparse | |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logging (1) | |
import tensorflow.compat.v1 as tf | |
from PIL import Image | |
from object_detection.utils import dataset_util, label_map_util | |
from collections import namedtuple | |
# Initiate argument parser
parser = argparse.ArgumentParser(
    description="Sample TensorFlow XML-to-TFRecord converter")
# -x, -l and -o have no usable default: without them the script fails later with an
# unrelated-looking error (e.g. load_labelmap(None) below), so argparse is told to
# reject the invocation immediately with a usage message instead.
parser.add_argument("-x",
                    "--xml_dir",
                    help="Path to the folder where the input .xml files are stored.",
                    type=str, required=True)
parser.add_argument("-l",
                    "--labels_path",
                    help="Path to the labels (.pbtxt) file.", type=str,
                    required=True)
parser.add_argument("-o",
                    "--output_path",
                    help="Path of output TFRecord (.record) file.", type=str,
                    required=True)
parser.add_argument("-i",
                    "--image_dir",
                    help="Path to the folder where the input image files are stored. "
                         "Defaults to the same directory as XML_DIR.",
                    type=str, default=None)
parser.add_argument("-c",
                    "--csv_path",
                    help="Path of output .csv file. If none provided, then no file will be "
                         "written.",
                    type=str, default=None)
args = parser.parse_args()
# Images are looked up next to their XML annotations unless -i is given.
if args.image_dir is None:
    args.image_dir = args.xml_dir
# Class-name -> integer-id lookup consulted by class_text_to_int().
label_map = label_map_util.load_labelmap(args.labels_path)
label_map_dict = label_map_util.get_label_map_dict(label_map)
def _xml_text(parent, tag_path, xml_file):
    """Return the stripped text of the first element matching ``tag_path`` under ``parent``.

    ``tag_path`` may be a simple tag ('name') or an ElementTree path ('bndbox/xmin').
    Raises ValueError naming ``xml_file`` when the element is absent or has no text,
    so a malformed annotation file is easy to locate.
    """
    node = parent.find(tag_path)
    if node is None or node.text is None:
        raise ValueError("{}: missing or empty <{}> element".format(xml_file, tag_path))
    return node.text.strip()


def _xml_int(parent, tag_path, xml_file):
    """Return the element at ``tag_path`` parsed as an int.

    The text goes through float() first so decimal values such as "48.0" are
    accepted instead of raising; any fractional part is truncated.
    """
    return int(float(_xml_text(parent, tag_path, xml_file)))


def xml_to_csv(path):
    """Iterates through all .xml files (generated by labelImg) in a given directory and combines
    them in a single Pandas dataframe.

    Elements are looked up by tag name rather than by child position, so files whose
    <object> children are ordered differently (or that omit <pose>, <truncated> or
    <difficult>) are read correctly.

    Parameters:
    ----------
    path : str
        The path containing the .xml files

    Returns
    -------
    Pandas DataFrame
        One row per <object>, with columns filename, width, height, class,
        xmin, ymin, xmax, ymax. Files are processed in sorted name order so the
        result is deterministic. An empty directory yields an empty frame with
        the same columns.

    Raises
    ------
    ValueError
        If an annotation that has objects is missing a required element, or a
        size/coordinate value is not numeric.
    """
    column_name = ['filename', 'width', 'height',
                   'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_list = []
    # glob order depends on the filesystem; sorting makes the output reproducible.
    for xml_file in sorted(glob.glob(os.path.join(path, '*.xml'))):
        root = ET.parse(xml_file).getroot()
        objects = root.findall('object')
        if not objects:
            # Image-level fields are only needed (and validated) when there is a box.
            continue
        filename = _xml_text(root, 'filename', xml_file)
        width = _xml_int(root, 'size/width', xml_file)
        height = _xml_int(root, 'size/height', xml_file)
        for member in objects:
            xml_list.append((
                filename,
                width,
                height,
                _xml_text(member, 'name', xml_file),
                _xml_int(member, 'bndbox/xmin', xml_file),
                _xml_int(member, 'bndbox/ymin', xml_file),
                _xml_int(member, 'bndbox/xmax', xml_file),
                _xml_int(member, 'bndbox/ymax', xml_file),
            ))
    return pd.DataFrame(xml_list, columns=column_name)
def class_text_to_int(row_label, label_map=None):
    """Return the integer class id for the class name ``row_label``.

    Args:
        row_label: class name as read from an annotation's <name> element.
        label_map: optional mapping of class name -> id. When omitted, the
            module-level ``label_map_dict`` (loaded from --labels_path) is used.

    Raises:
        KeyError: if ``row_label`` is not in the mapping. The message lists the
            known labels so a typo or a missing .pbtxt entry is easy to spot.
    """
    mapping = label_map_dict if label_map is None else label_map
    try:
        return mapping[row_label]
    except KeyError:
        raise KeyError(
            "Label {!r} not found in label map (known labels: {})".format(
                row_label, sorted(mapping))) from None
def split(df, group):
    """Partition ``df`` by the values of column ``group``.

    Returns a list with one ``data(filename, object)`` namedtuple per distinct
    value, where ``filename`` is that value and ``object`` is the sub-DataFrame
    (original index preserved) of the rows sharing it.
    """
    data = namedtuple('data', ['filename', 'object'])
    grouped = df.groupby(group)
    result = []
    for key in grouped.groups:
        result.append(data(key, grouped.get_group(key)))
    return result
def create_tf_example(group, path):
    """Build a ``tf.train.Example`` for one image and all of its boxes.

    Args:
        group: a ``data(filename, object)`` record as produced by ``split()``;
            ``object`` is the DataFrame of annotation rows (columns ``class``,
            ``xmin``, ``ymin``, ``xmax``, ``ymax``) for ``group.filename``.
        path: directory containing the image file named ``group.filename``.

    Returns:
        A ``tf.train.Example`` holding the raw encoded image bytes, the image's
        decoded width/height, its format, and per-box coordinates divided by
        that width/height, plus class names and integer ids.
    """
    with tf.gfile.GFile(os.path.join(path, group.filename), 'rb') as fid:
        encoded_image = fid.read()
    image = Image.open(io.BytesIO(encoded_image))
    # Use the decoded image's real size rather than the <size> stored in the XML.
    width, height = image.size
    # Record the format PIL actually detected (e.g. b'jpeg', b'png') instead of
    # always claiming JPEG; fall back to JPEG if PIL cannot tell.
    image_format = (image.format or 'JPEG').lower().encode('utf8')
    filename = group.filename.encode('utf8')

    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text, classes = [], []
    for _, row in group.object.iterrows():
        # Pixel coordinates -> fractions of the image dimensions.
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    return tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
def main(_):
    """Convert every annotated image under --xml_dir into one TFRecord file.

    Reads the annotations with xml_to_csv(), groups them per image with split(),
    and writes one serialized tf.train.Example per image to --output_path. When
    --csv_path is given, the flattened annotation table is also saved as CSV.
    The positional argument is unused (presumably the argv passed in by
    tf.app.run); all options come from the module-level ``args``.
    """
    examples = xml_to_csv(args.xml_dir)
    if examples.empty:
        print('Warning: no annotated objects found in {}'.format(args.xml_dir))
    # The context manager closes the writer even if building an example fails.
    with tf.python_io.TFRecordWriter(args.output_path) as writer:
        for group in split(examples, 'filename'):
            tf_example = create_tf_example(group, args.image_dir)
            writer.write(tf_example.SerializeToString())
    print('Successfully created the TFRecord file: {}'.format(args.output_path))
    if args.csv_path is not None:
        examples.to_csv(args.csv_path, index=False)
        print('Successfully created the CSV file: {}'.format(args.csv_path))
if __name__ == '__main__':
    # Run through TensorFlow's app runner, naming the entry point explicitly
    # rather than relying on it to look up `main` in the __main__ module.
    tf.app.run(main=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment