Skip to content

Instantly share code, notes, and snippets.

@dustindorroh
Last active September 2, 2019 07:19
Show Gist options
  • Save dustindorroh/35d6b789b8af3a6b7dc631a5bca8a887 to your computer and use it in GitHub Desktop.
Save dustindorroh/35d6b789b8af3a6b7dc631a5bca8a887 to your computer and use it in GitHub Desktop.
Creating keypoints TFRecords using TensorFlow's object_detection.
# create_head_tf_record.py
#
# Created by Dustin Dorroh on 1/07/2019
#
import hashlib
import io
import json
import logging
import os
import re
from itertools import chain
from pathlib import Path
import contextlib2
import pandas as pd
import tensorflow as tf
from PIL import Image
from object_detection.dataset_tools import tf_record_creation_util
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
# Command-line configuration via tf.app.flags; values are read in main().
flags = tf.app.flags
# Directory that will receive the sharded *.record output files.
flags.DEFINE_string('output_dir', '', 'Path to directory to output TFRecords.')
# Label-map proto mapping class names (e.g. 'face') to integer ids.
flags.DEFINE_string('label_map_path', 'data/face_label_map.pbtxt',
'Path to label map proto')
# Number of shards for the training record; the val record uses 1/10th of this.
flags.DEFINE_integer('num_shards', 10, 'Number of TFRecord shards')
# CSV with one row per annotated face (image path, bbox, landmark columns).
flags.DEFINE_string('csv_path', '', 'Path to examples file. A flat list of annotation files')
# Prefix for the generated '<name>_train.record' / '<name>_val.record' files.
flags.DEFINE_string('dataset_name', '', 'Name of the dataset. It will be prefixed to the output tfrecords.')
FLAGS = flags.FLAGS
def pil_loader(path):
    """Open the image at *path* and return it converted to RGB.

    The file handle is opened explicitly and closed by the context
    manager, avoiding the ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835.
    """
    with open(path, 'rb') as fp, Image.open(fp) as image:
        return image.convert('RGB')
def read_encode_jpg_buff(path):
    """Read the image at *path* and return its JPEG encoding as a buffer.

    Modes that JPEG cannot store directly (palette and alpha-carrying
    modes) are converted to RGB first.  The returned io.BytesIO is
    rewound to offset 0, ready for reading.
    """
    with open(path, 'rb') as fp, Image.open(fp) as image:
        if image.mode in ('RGBA', 'P', 'LA'):
            image = image.convert('RGB')
        jpeg_buffer = io.BytesIO()
        image.save(jpeg_buffer, format='JPEG')
        jpeg_buffer.seek(0)
        return jpeg_buffer
def read_json(path):
    """Parse the JSON document stored at *path* and return it."""
    with open(path, 'r') as fp:
        return json.load(fp)
def get_class_name_from_filename(file_name):
    """Gets the class name from a file.

    Args:
        file_name: The file name to get the class name from,
            e.g. "american_pit_bull_terrier_105.jpg".

    Returns:
        A string of the class name
        (e.g. "american_pit_bull_terrier").
    """
    pattern = re.compile(r'([A-Za-z_]+)(_[0-9]+\.jpg)', re.I)
    return pattern.match(file_name).group(1)
def dict_to_tf_example(row,
                       label_map_dict,
                       class_name='face',
                       num_landmarks=106):
    """Convert one annotation row into a tf.Example proto.

    PNG inputs are re-encoded to JPEG in memory; other inputs must
    already be JPEG.  Bounding-box coordinates are taken from the
    Rekognition-normalized columns, so they are assumed to lie in
    [0, 1] already — TODO confirm against the CSV producer.

    Args:
        row: pandas Series with 'image_path', the
            'rekognition_normed_{x,y}{min,max}' bbox columns and the
            'x0'..'x{n-1}' / 'y0'..'y{n-1}' landmark columns.
        label_map_dict: A map from string label names to integer ids.
        class_name: Class label assigned to the single face in the row.
        num_landmarks: Number of landmark (keypoint) columns per axis.

    Returns:
        example: The converted tf.Example.

    Raises:
        ValueError: if the (non-PNG) image pointed to by
            row['image_path'] is not a valid JPEG.
    """
    img_path = row.image_path
    if Path(img_path).suffix.lower() == '.png':
        # Re-encode the PNG to JPEG in memory.  Hash and metadata are
        # taken from the encoded bytes, so the source file is read once
        # (the original re-opened it a second time and leaked the handle).
        encoded_jpg = read_encode_jpg_buff(img_path).read()
    else:
        with tf.gfile.GFile(img_path, 'rb') as fid:
            encoded_jpg = fid.read()
    key = hashlib.sha256(encoded_jpg).hexdigest()
    # Context manager closes the PIL image (the original never closed it).
    # For the PNG branch the bytes are JPEG by construction, so the
    # format check is a no-op there and preserves the old behavior.
    with Image.open(io.BytesIO(encoded_jpg)) as image:
        if image.format != 'JPEG':
            raise ValueError('Image format not JPEG')
        width, height = image.size
    x_columns = ['x{}'.format(i) for i in range(num_landmarks)]
    y_columns = ['y{}'.format(i) for i in range(num_landmarks)]
    classes = [label_map_dict[class_name]]
    classes_text = [class_name.encode('utf8')]
    feature_dict = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(row['image_path'].encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(row['image_path'].encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature([row.rekognition_normed_xmin]),
        'image/object/bbox/xmax': dataset_util.float_list_feature([row.rekognition_normed_xmax]),
        'image/object/bbox/ymin': dataset_util.float_list_feature([row.rekognition_normed_ymin]),
        'image/object/bbox/ymax': dataset_util.float_list_feature([row.rekognition_normed_ymax]),
        'image/object/keypoint/y': dataset_util.float_list_feature(row[y_columns].tolist()),
        'image/object/keypoint/x': dataset_util.float_list_feature(row[x_columns].tolist()),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature([]),
        'image/object/truncated': dataset_util.int64_list_feature([]),
        # Empty list, not b'': bytes_list_feature expects a list of byte
        # strings, and b'' only worked by accident of being empty.
        'image/object/view': dataset_util.bytes_list_feature([]),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example
def create_tf_record(output_filename,
                     num_shards,
                     label_map_dict,
                     df):
    """Creates a sharded TFRecord file from examples.

    Rows that fail to convert (ValueError from dict_to_tf_example) are
    logged and skipped rather than aborting the run.

    Args:
        output_filename: Path to where output file is saved.
        num_shards: Number of shards for output file (must be >= 1).
        label_map_dict: The label map dictionary.
        df: Examples (one row per face) to parse and save to tf record.
    """
    with contextlib2.ExitStack() as tf_record_close_stack:
        output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
            tf_record_close_stack, output_filename, num_shards)
        # Use the 0-based position, not the DataFrame index label: after
        # main()'s sample()/drop() the labels are arbitrary, which made
        # the original progress log wrong and skewed label % num_shards
        # shard assignment.
        for position, (_, row) in enumerate(df.iterrows()):
            if position % 100 == 0:
                logging.info('On image %d of %d', position, len(df))
            try:
                tf_example = dict_to_tf_example(row, label_map_dict)
            except ValueError:
                logging.warning('Invalid example: %s, ignoring.', row.image_path)
                continue
            # Keep the write outside the try so an I/O error is not
            # silently swallowed as an "invalid example".
            if tf_example:
                shard_idx = position % num_shards
                output_tfrecords[shard_idx].write(tf_example.SerializeToString())
def main(_):
    """Split the CSV 90/10 into train/val and write both TFRecord sets."""
    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
    # Lazy %-formatting so the string is only built if the record is emitted.
    logging.info('Reading from %s dataset.', FLAGS.dataset_name)
    df = pd.read_csv(FLAGS.csv_path)
    # Hold out a random 10% of rows for validation; the rest is training.
    # NOTE(review): no random_state is set, so the split is not reproducible
    # across runs — confirm whether that is intentional.
    val_df = df.sample(frac=.1)
    train_df = df.drop(index=val_df.index)
    train_output_path = os.path.join(
        FLAGS.output_dir, '{}_train.record'.format(FLAGS.dataset_name))
    val_output_path = os.path.join(
        FLAGS.output_dir, '{}_val.record'.format(FLAGS.dataset_name))
    create_tf_record(
        train_output_path,
        FLAGS.num_shards,
        label_map_dict,
        train_df)
    # The val split is ~10x smaller, so use proportionally fewer shards —
    # but never fewer than one: num_shards // 10 is 0 for num_shards < 10,
    # which would create no output shards and break shard indexing.
    create_tf_record(
        val_output_path,
        max(1, FLAGS.num_shards // 10),
        label_map_dict,
        val_df)


if __name__ == '__main__':
    tf.app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment