Last active
April 30, 2019 18:02
-
-
Save yongjun823/178afe3a00e0031b07022bf1b9ceb26e to your computer and use it in GitHub Desktop.
TensorFlow Object Detection API: TFRecord creation code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
import base64 | |
import csv | |
import os | |
from PIL import Image | |
from utils import dataset_util | |
""" | |
CSV format (one row per image, integer pixel values):
xmin, ymin, xmax, ymax, class
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md
""" | |
# Class names in id order: the position in this list, counting from 1, is the
# integer class id used in the CSV's `class` column. Names are bytes because
# they are written directly into a bytes_list feature.
_CLASS_NAMES = [b'bird', b'tiger']
class_dict = dict(enumerate(_CLASS_NAMES, start=1))
def create_tf_example(img_name, img_path, bbox_data):
    """Build a tf.train.Example for one image with a single bounding box.

    Args:
        img_name: File name of the image (e.g. ``'0001.jpg'``). Stored as the
            example's filename and source_id; its extension (text after the
            last dot) is stored as the image format.
        img_path: Path the image file is read from.
        bbox_data: Sequence ``[xmin, ymin, xmax, ymax, class_id]`` with the
            box corners in pixels, as parsed from one CSV row. ``class_id``
            must be a key of ``class_dict``. Extra trailing values are ignored.

    Returns:
        A ``tf.train.Example`` using the Object Detection API feature keys,
        with box coordinates normalized to the image width/height.

    Raises:
        KeyError: if ``class_id`` is not a key of ``class_dict``.
    """
    # 'image/encoded' must hold the raw JPEG/PNG file bytes. Base64-encoding
    # them (as an earlier version did) makes the image undecodable at
    # training time.
    with open(img_path, 'rb') as img_file:
        encoded_image_data = img_file.read()

    # PIL is only used for the dimensions; `with` closes the file promptly.
    with Image.open(img_path) as im:
        width, height = im.size

    filename = img_name.encode('utf8')
    # Use the text after the LAST dot so names like 'a.b.jpg' give b'jpg'.
    image_format = os.path.splitext(img_name)[1].lstrip('.').encode('utf8')

    xmin, ymin, xmax, ymax, class_id = bbox_data[:5]
    # float() avoids integer division (all-zero boxes) under Python 2.
    xmins = [float(xmin) / width]
    xmaxs = [float(xmax) / width]
    ymins = [float(ymin) / height]
    ymaxs = [float(ymax) / height]
    classes = [class_id]
    classes_text = [class_dict[class_id]]

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main(_):
    """Write every image in ./mydata, labelled from ./bbox.csv, to one TFRecord.

    Images are paired with CSV rows by position: the i-th image file (in
    sorted file-name order) gets the i-th non-blank CSV row. Each row must
    hold integer ``xmin, ymin, xmax, ymax, class`` values. The output is
    written to ./tfrecord/data.tfrecords; the directory is created if missing.

    Args:
        _: argv passed by ``tf.app.run``; unused.
    """
    data_path = './mydata'
    label_csv = './bbox.csv'
    output_path = './tfrecord/data.tfrecords'

    # os.listdir returns entries in arbitrary filesystem order; sort so the
    # positional pairing with CSV rows is deterministic.
    # NOTE(review): assumes bbox.csv rows are listed in sorted file-name
    # order - confirm against how the CSV was produced.
    images = sorted(os.listdir(data_path))

    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # `with` guarantees the CSV file and the record writer are closed even if
    # an image fails to load.
    with open(label_csv, 'r') as csv_file, \
            tf.python_io.TFRecordWriter(output_path) as writer:
        # Skip blank lines: csv.reader yields [] for them, which would crash
        # create_tf_example.
        rows = (row for row in csv.reader(csv_file) if row)
        # NOTE: zip stops at the shorter of images/rows, so surplus images or
        # rows are silently ignored.
        for img_name, row in zip(images, rows):
            label = [int(value) for value in row]
            img_path = os.path.join(data_path, img_name)
            tf_example = create_tf_example(img_name, img_path, label)
            writer.write(tf_example.SerializeToString())
if __name__ == '__main__':
    # tf.app.run parses command-line flags, then calls main(argv); passing
    # main explicitly is what tf.app.run would look up by default.
    tf.app.run(main=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
If I use
encoded_image_data = base64.b64encode(img_file.read()) # Encoded image bytes
to encode the image, I get an error when I try to run train.py.
But the following code, which stores the raw file bytes without base64 encoding, works for me:
encoded_image_data = img_file.read()
Thanks to https://stackoverflow.com/questions/46687348/decoding-tfrecord-with-tfslim.