Created
October 4, 2020 13:48
-
-
Save seatedro/4862f3096affd70f1105a547c0264a70 to your computer and use it in GitHub Desktop.
Create TF Record
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
import csv | |
import re | |
import cv2 | |
import os | |
import glob | |
import xml.etree.ElementTree as ET | |
import io | |
import tensorflow as tf | |
from collections import namedtuple, OrderedDict | |
import shutil | |
import urllib.request | |
import tarfile | |
import argparse | |
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logging (1) | |
import tensorflow.compat.v1 as tf | |
from PIL import Image | |
from object_detection.utils import dataset_util, label_map_util | |
from collections import namedtuple | |
# Hard-coded inputs/outputs for this run. The XML annotations and the images
# they describe are read from the same directory.
image_dir = 'images/test'
xml_dir = 'images/test'
output_path = 'annotations/test.record'

# Label map loaded when the module is imported; label_map_dict is the lookup
# table class_text_to_int uses to turn class-name text into integer ids.
label_map = label_map_util.load_labelmap('annotations/label_map.pbtxt')
label_map_dict = label_map_util.get_label_map_dict(label_map)
def xml_to_csv(path):
    """Combine all labelImg-style ``.xml`` annotation files in a directory into one DataFrame.

    Every ``<object>`` element of every file becomes one row. Child elements are
    looked up by tag name rather than by position. Files whose ``<object>`` or
    ``<size>`` children appear in another order, or carry extra tags such as
    ``<occluded>``, are therefore parsed correctly.

    Parameters
    ----------
    path : str
        Directory containing the ``.xml`` files (not searched recursively).

    Returns
    -------
    pandas.DataFrame
        Columns ``filename, width, height, class, xmin, ymin, xmax, ymax``, one
        row per annotated object. Files are visited in sorted path order so the
        output is deterministic. When no objects are found, the frame is empty
        but still has these columns.
    """
    def _to_int(element, tag):
        # Some annotation tools write pixel values as "12.0"; accept those too.
        return int(round(float(element.find(tag).text)))

    xml_list = []
    # glob() returns files in arbitrary order; sort for reproducible records.
    for xml_file in sorted(glob.glob(os.path.join(path, '*.xml'))):
        root = ET.parse(xml_file).getroot()
        objects = root.findall('object')
        if not objects:
            # Nothing annotated: skip without requiring <filename>/<size>.
            continue
        filename = root.find('filename').text
        size = root.find('size')
        width = _to_int(size, 'width')
        height = _to_int(size, 'height')
        for member in objects:
            bndbox = member.find('bndbox')
            xml_list.append((
                filename,
                width,
                height,
                member.find('name').text,
                _to_int(bndbox, 'xmin'),
                _to_int(bndbox, 'ymin'),
                _to_int(bndbox, 'xmax'),
                _to_int(bndbox, 'ymax'),
            ))
    column_name = ['filename', 'width', 'height',
                   'class', 'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(xml_list, columns=column_name)
def class_text_to_int(row_label, mapping=None):
    """Return the integer class id for the class-name text ``row_label``.

    Parameters
    ----------
    row_label : str
        Class name as written in the annotation file.
    mapping : dict, optional
        Name -> id lookup table. Defaults to the module-level ``label_map_dict``
        loaded from the label map file.

    Raises
    ------
    KeyError
        If ``row_label`` is not in the mapping. The message names the missing
        class and lists the known ones, so a typo in an annotation is easy to spot.
    """
    if mapping is None:
        mapping = label_map_dict
    try:
        return mapping[row_label]
    except KeyError:
        raise KeyError(
            'Class {!r} is not present in the label map; known classes: {}'.format(
                row_label, sorted(mapping))
        ) from None
def split(df, group):
    """Group ``df`` by the ``group`` column and return one record per distinct value.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to partition (e.g. the output of ``xml_to_csv``).
    group : str
        Column name to group by (``'filename'`` in this script).

    Returns
    -------
    list of data
        Named tuples ``data(filename, object)``. ``filename`` is the group key
        and ``object`` is the sub-DataFrame of matching rows, keeping the
        original index. Keys come out in pandas' default sorted groupby order.
    """
    data = namedtuple('data', ['filename', 'object'])
    # Iterating the GroupBy yields (key, rows) pairs directly. No need for the
    # former zip of the keys with themselves plus a get_group() per key.
    return [data(filename, rows) for filename, rows in df.groupby(group)]
def _tf_image_format(pil_format):
    """Return the ``image/format`` bytes for a PIL format name such as 'JPEG' or 'PNG'."""
    # JPEG (or an undetected format) keeps the b'jpg' value this script always wrote.
    if pil_format is None or pil_format.upper() == 'JPEG':
        return b'jpg'
    return pil_format.lower().encode('ascii')


def create_tf_example(group, path):
    """Build a ``tf.train.Example`` for one image and all of its bounding boxes.

    Parameters
    ----------
    group : data
        Named tuple as produced by ``split``. ``group.filename`` is the image
        file name. ``group.object`` holds that image's rows, with pixel columns
        ``xmin, ymin, xmax, ymax`` and the class-name column ``class``.
    path : str
        Directory containing the image file.

    Returns
    -------
    tf.train.Example
        Holds the raw encoded image bytes, the image width/height and format,
        and per-box coordinates divided by the image's own width/height. Feature
        keys use the ``image/...`` names built with ``dataset_util``.

    Raises
    ------
    KeyError
        If a box's class text is missing from the label map (raised by
        ``class_text_to_int``).
    """
    with tf.gfile.GFile(os.path.join(path, group.filename), 'rb') as fid:
        encoded_image = fid.read()
    image = Image.open(io.BytesIO(encoded_image))
    # Size comes from the decoded image, not from the XML <size> element.
    width, height = image.size
    # Describe the bytes actually stored (PNG, BMP, ...) instead of always claiming JPEG.
    image_format = _tf_image_format(image.format)
    filename = group.filename.encode('utf8')

    boxes = group.object
    xmins = (boxes['xmin'] / width).tolist()
    xmaxs = (boxes['xmax'] / width).tolist()
    ymins = (boxes['ymin'] / height).tolist()
    ymaxs = (boxes['ymax'] / height).tolist()
    class_names = boxes['class'].tolist()
    classes_text = [name.encode('utf8') for name in class_names]
    classes = [class_text_to_int(name) for name in class_names]

    return tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
# Optional destination for a CSV dump of the parsed annotations; None disables it.
csv_path = None


def main(_):
    """Convert the XML annotations in ``xml_dir`` into a TFRecord file at ``output_path``.

    The images referenced by the annotations are read from ``image_dir``. When
    ``csv_path`` is set, the parsed annotation table is also written there as CSV.
    The positional argument is ignored (presumably the argv passed by ``tf.app.run``).
    """
    examples = xml_to_csv(xml_dir)
    # The context manager closes (and flushes) the writer even if building an
    # example raises, e.g. on a missing image or an unknown class label.
    with tf.python_io.TFRecordWriter(output_path) as writer:
        for group in split(examples, 'filename'):
            tf_example = create_tf_example(group, image_dir)
            writer.write(tf_example.SerializeToString())
    print('Successfully created the TFRecord file: {}'.format(output_path))
    if csv_path is not None:
        examples.to_csv(csv_path, index=None)
        print('Successfully created the CSV file: {}'.format(csv_path))


if __name__ == '__main__':
    tf.app.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment