Creating keypoint TFRecords using TensorFlow's object_detection API.
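The script reads a flat annotations CSV (passed via --csv_path) with one row per face. Based on the columns referenced in dict_to_tf_example below, a minimal sketch of the expected header looks like the following (column names are taken from the code; the landmark count follows the num_landmarks default of 106):

expected_columns = (
    ['image_path',
     'rekognition_normed_xmin', 'rekognition_normed_xmax',
     'rekognition_normed_ymin', 'rekognition_normed_ymax']
    + ['x{}'.format(i) for i in range(106)]  # landmark x coordinates x0..x105
    + ['y{}'.format(i) for i in range(106)]  # landmark y coordinates y0..y105
)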
# create_head_tf_record.py
#
# Created by Dustin Dorroh on 1/07/2019
#
import hashlib
import io
import json
import logging
import os
import re
from itertools import chain
from pathlib import Path

import contextlib2
import pandas as pd
import tensorflow as tf
from PIL import Image

from object_detection.dataset_tools import tf_record_creation_util
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

flags = tf.app.flags
flags.DEFINE_string('output_dir', '', 'Path to directory to output TFRecords.')
flags.DEFINE_string('label_map_path', 'data/face_label_map.pbtxt',
                    'Path to label map proto')
flags.DEFINE_integer('num_shards', 10, 'Number of TFRecord shards')
flags.DEFINE_string('csv_path', '', 'Path to the annotations CSV (one row per face).')
flags.DEFINE_string('dataset_name', '', 'Name of the dataset. It will be prefixed to the output tfrecords.')
FLAGS = flags.FLAGS

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def read_encode_jpg_buff(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            buf = io.BytesIO()
            if img.mode in ('RGBA', 'P', 'LA'):
                img = img.convert('RGB')
            img.save(buf, format='JPEG')
            buf.seek(0)
            return buf

def read_json(path):
    with open(path, 'r') as f:
        return json.load(f)

def get_class_name_from_filename(file_name):
    """Gets the class name from a file.

    Args:
      file_name: The file name to get the class name from.
        ie. "american_pit_bull_terrier_105.jpg"

    Returns:
      A string of the class name.
    """
    match = re.match(r'([A-Za-z_]+)(_[0-9]+\.jpg)', file_name, re.I)
    return match.groups()[0]

def dict_to_tf_example(row,
                       label_map_dict,
                       # ignore_difficult_instances=False,
                       # normalized_bbox=True,
                       class_name='face',
                       num_landmarks=106):
    """Converts a CSV-derived pandas row to a tf.Example proto.

    Note that the bounding box coordinates in the row are expected to already be
    normalized (the rekognition_normed_* columns).

    Args:
      row: pandas row holding the bounding box and all landmark labels for a face.
      label_map_dict: A map from string label names to integer ids.
      class_name: Class name assigned to the single object in each example
        (default: 'face').
      num_landmarks: Number of facial landmarks stored per face (default: 106).

    Returns:
      example: The converted tf.Example.

    Raises:
      ValueError: if the image pointed to by row['image_path'] is not a valid JPEG.
    """
    img_path = row.image_path
    if Path(img_path).suffix.lower() == '.png':
        encoded_jpg_io = read_encode_jpg_buff(img_path)
        key = hashlib.sha256(encoded_jpg_io.read()).hexdigest()
        encoded_jpg_io.seek(0)  # Seek back because the hexdigest just read through the buffer
        encoded_jpg = encoded_jpg_io.read()
        image = Image.open(row.image_path)  # Only used for metadata; we already encoded our JPEG version
    else:
        with tf.gfile.GFile(img_path, 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = Image.open(encoded_jpg_io)
        if image.format != 'JPEG':
            raise ValueError('Image format not JPEG')
        key = hashlib.sha256(encoded_jpg).hexdigest()

    width = image.width
    height = image.height

    x_columns = ['x{}'.format(i) for i in range(num_landmarks)]
    y_columns = ['y{}'.format(i) for i in range(num_landmarks)]
    point_columns = list(chain.from_iterable(zip(y_columns, x_columns)))

    classes = [label_map_dict[class_name]]
    classes_text = [class_name.encode('utf8')]
    feature_dict = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(row['image_path'].encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(row['image_path'].encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature([row.rekognition_normed_xmin]),
        'image/object/bbox/xmax': dataset_util.float_list_feature([row.rekognition_normed_xmax]),
        'image/object/bbox/ymin': dataset_util.float_list_feature([row.rekognition_normed_ymin]),
        'image/object/bbox/ymax': dataset_util.float_list_feature([row.rekognition_normed_ymax]),
        # 'groundtruth_landmarks/x': dataset_util.float_list_feature(row[x_columns]),
        # 'groundtruth_landmarks/y': dataset_util.float_list_feature(row[y_columns]),
        # groundtruth_instance_masks: ground truth instance masks.
        # groundtruth_instance_boundaries: ground truth instance boundaries.
        # groundtruth_instance_classes: instance mask-level class labels.
        # groundtruth_keypoints: ground truth keypoints.
        # groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
        # groundtruth_label_weights: groundtruth label weights.
        # groundtruth_weights: groundtruth weight factor for bounding boxes.
        # num_groundtruth_boxes: number of groundtruth boxes.
        # is_annotated: whether an image has been labeled or not.
        # 'groundtruth_keypoints': dataset_util.float_list_feature(row[point_columns].tolist()),
        'image/object/keypoint/y': dataset_util.float_list_feature(row[y_columns].tolist()),
        'image/object/keypoint/x': dataset_util.float_list_feature(row[x_columns].tolist()),
        # 'image/object/pose/yaw': dataset_util.float_list_feature([row['Yaw']]),
        # 'image/object/pose/pitch': dataset_util.float_list_feature([row['Pitch']]),
        # 'image/object/pose/roll': dataset_util.float_list_feature([row['Roll']]),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature([]),
        'image/object/truncated': dataset_util.int64_list_feature([]),
        'image/object/view': dataset_util.bytes_list_feature([]),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example
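
# A minimal sketch (with hypothetical paths) of exercising dict_to_tf_example on a
# single CSV row in an interactive session, e.g. to sanity-check one example before
# writing full TFRecords:
#
#   label_map_dict = label_map_util.get_label_map_dict('data/face_label_map.pbtxt')
#   row = pd.read_csv('annotations.csv').iloc[0]  # 'annotations.csv' is a placeholder path
#   example = dict_to_tf_example(row, label_map_dict)
#   print(len(example.features.feature['image/object/keypoint/x'].float_list.value))  # expect 106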

def create_tf_record(output_filename,
                     num_shards,
                     label_map_dict,
                     df):
    """Creates a TFRecord file from examples.

    Args:
      output_filename: Path to where output file is saved.
      num_shards: Number of shards for output file.
      label_map_dict: The label map dictionary.
      df: Examples to parse and save to tf record.
    """
    with contextlib2.ExitStack() as tf_record_close_stack:
        output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
            tf_record_close_stack, output_filename, num_shards)
        # Use enumerate rather than the DataFrame index: after the train/val split the
        # index is neither contiguous nor zero-based, which would skew the progress
        # logging and the shard assignment below.
        for idx, (_, row) in enumerate(df.iterrows()):
            if idx % 100 == 0:
                logging.info('On image %d of %d', idx, len(df))
            try:
                tf_example = dict_to_tf_example(
                    row,
                    label_map_dict)
                if tf_example:
                    shard_idx = idx % num_shards
                    output_tfrecords[shard_idx].write(tf_example.SerializeToString())
            except ValueError:
                logging.warning('Invalid example: %s, ignoring.', row.image_path)

def main(_):
    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

    logging.info('Reading from {} dataset.'.format(FLAGS.dataset_name))
    df = pd.read_csv(FLAGS.csv_path)

    # Hold out a random 10% of the rows for validation; the rest go to training.
    val_df = df.sample(frac=.1)
    train_df = df.drop(index=val_df.index)

    train_output_path = os.path.join(FLAGS.output_dir, '{}_train.record'.format(FLAGS.dataset_name))
    val_output_path = os.path.join(FLAGS.output_dir, '{}_val.record'.format(FLAGS.dataset_name))

    create_tf_record(
        train_output_path,
        FLAGS.num_shards,
        label_map_dict,
        train_df)
    # The validation set is roughly a tenth of the data, so use a tenth of the shards.
    create_tf_record(
        val_output_path,
        FLAGS.num_shards // 10,
        label_map_dict,
        val_df)


if __name__ == '__main__':
    tf.app.run()
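
To spot-check the output, here is a minimal sketch of reading back one generated shard and inspecting the keypoint features. The shard path below is hypothetical; open_sharded_output_tfrecords appends a "-00000-of-0000N" style suffix to the base record path.

import tensorflow as tf

shard_path = 'output/mydataset_train.record-00000-of-00010'  # hypothetical path
for serialized in tf.python_io.tf_record_iterator(shard_path):
    example = tf.train.Example.FromString(serialized)
    feature = example.features.feature
    xs = feature['image/object/keypoint/x'].float_list.value
    ys = feature['image/object/keypoint/y'].float_list.value
    print(feature['image/filename'].bytes_list.value[0], len(xs), len(ys))
    break  # only inspect the first record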