Created
April 20, 2019 01:34
-
-
Save FreeFly19/07a8472678727c448fe56cf9634e3c31 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Usage: | |
# From tensorflow/models/ | |
# Create train data: | |
python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record | |
# Create test data: | |
python generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=test.record | |
""" | |
from __future__ import division | |
from __future__ import print_function | |
from __future__ import absolute_import | |
import os | |
import io | |
import pandas as pd | |
import tensorflow as tf | |
from PIL import Image | |
from object_detection.utils import dataset_util | |
from collections import namedtuple, OrderedDict | |
# Command-line flags; tf.app.run() parses them before invoking main().
flags = tf.app.flags
# Annotation CSV to convert (read with pandas.read_csv in main()).
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
# Destination TFRecord file written by main().
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
# Directory holding the images named in the CSV's 'filename' column; with the
# empty default, image paths are resolved relative to the current directory.
flags.DEFINE_string('image_dir', '', 'Path to images')
FLAGS = flags.FLAGS
# TODO: replace this hard-coded mapping with a label map file.
# Label ids start at 1; 0 is intentionally left unused.
_CLASS_LABELS = {
    'D00': 1,
    'D01': 2,
    'D10': 3,
    'D11': 4,
    'D20': 5,
    'D30': 6,
    'D40': 7,
    'D43': 8,
    'D44': 9,
}


def class_text_to_int(row_label):
    """Return the integer label id for the class name ``row_label``.

    Args:
        row_label: class name as it appears in the annotation CSV (e.g. 'D00').

    Returns:
        int id in the range 1..9.

    Raises:
        KeyError: if ``row_label`` is not a known class; the message lists the
            accepted names so a typo in the annotations is easy to spot.
    """
    try:
        return _CLASS_LABELS[row_label]
    except KeyError:
        raise KeyError('Unknown class label {!r}; expected one of {}'.format(
            row_label, sorted(_CLASS_LABELS)))
def split(df, group):
    """Group the rows of ``df`` by the ``group`` column.

    Args:
        df: pandas DataFrame of annotation rows.
        group: name of the column to group by (e.g. 'filename').

    Returns:
        list of ``data(filename, object)`` namedtuples, one per distinct value
        of ``df[group]``: ``filename`` is that value and ``object`` is the
        sub-DataFrame containing every row with it.
    """
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    # Iterating ``gb.groups`` yields each group key exactly once.
    return [data(key, gb.get_group(key)) for key in gb.groups]
def create_tf_example(group, path):
    """Serialize one image and all of its bounding boxes into a tf.train.Example.

    Args:
        group: a ``data(filename, object)`` record as returned by ``split``;
            ``filename`` is the image name (joined onto ``path``) and
            ``object`` is a DataFrame of that image's rows with ``xmin``,
            ``xmax``, ``ymin``, ``ymax`` and ``class`` columns.
        path: directory the image file is read from.

    Returns:
        tf.train.Example holding the raw encoded image bytes, the image size,
        box corners divided by the image width/height, and the class names
        together with their integer ids from ``class_text_to_int``.
    """
    image_file = os.path.join(path, '{}'.format(group.filename))
    with tf.gfile.GFile(image_file, 'rb') as handle:
        encoded = handle.read()

    # Decode only to learn the pixel dimensions; the bytes are stored
    # unchanged. PIL's ``size`` is (width, height).
    img_width, img_height = Image.open(io.BytesIO(encoded)).size
    source_name = group.filename.encode('utf8')

    corners = {'xmin': [], 'xmax': [], 'ymin': [], 'ymax': []}
    label_names = []
    label_ids = []
    for _, box in group.object.iterrows():
        corners['xmin'].append(box['xmin'] / img_width)
        corners['xmax'].append(box['xmax'] / img_width)
        corners['ymin'].append(box['ymin'] / img_height)
        corners['ymax'].append(box['ymax'] / img_height)
        label_names.append(box['class'].encode('utf8'))
        label_ids.append(class_text_to_int(box['class']))

    feature_map = {
        'image/height': dataset_util.int64_feature(img_height),
        'image/width': dataset_util.int64_feature(img_width),
        'image/filename': dataset_util.bytes_feature(source_name),
        'image/source_id': dataset_util.bytes_feature(source_name),
        'image/encoded': dataset_util.bytes_feature(encoded),
        # NOTE: the format tag is always b'jpg', whatever the actual file type.
        'image/format': dataset_util.bytes_feature(b'jpg'),
        'image/object/bbox/xmin': dataset_util.float_list_feature(corners['xmin']),
        'image/object/bbox/xmax': dataset_util.float_list_feature(corners['xmax']),
        'image/object/bbox/ymin': dataset_util.float_list_feature(corners['ymin']),
        'image/object/bbox/ymax': dataset_util.float_list_feature(corners['ymax']),
        'image/object/class/text': dataset_util.bytes_list_feature(label_names),
        'image/object/class/label': dataset_util.int64_list_feature(label_ids),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))
def main(_):
    """Convert the annotation CSV in ``--csv_input`` into a TFRecord file.

    Rows are grouped per image filename; each image is read from
    ``--image_dir`` (the current directory when the flag is empty) and
    written as one tf.train.Example to ``--output_path``.
    """
    # os.path.join with a single argument returns it unchanged, so use the
    # flag value directly.
    path = FLAGS.image_dir
    # Read the CSV before opening the writer so a bad CSV does not leave a
    # truncated, unclosed output file behind.
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')

    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    try:
        for group in grouped:
            tf_example = create_tf_example(group, path)
            writer.write(tf_example.SerializeToString())
    finally:
        # Always close the record file, even if an image fails to load.
        writer.close()

    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    tf.app.run()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Root of the tensorflow/models "research" checkout used throughout.
TF_RESEARCH=/home/freefly19/projects/tf-models/research

# Make research/, slim/ and object_detection/ importable for the steps below.
export PYTHONPATH=$PYTHONPATH:$TF_RESEARCH:$TF_RESEARCH/slim:$TF_RESEARCH/object_detection:$TF_RESEARCH/object_detection/utils

# 1) Copy annotated images of each region into train/ and test/.
python split.py
# 2) Flatten the XML annotations into data/{train,test}_labels.csv.
python xml_to_csv.py

# 3) Build TFRecords. Run from inside each split directory because
#    generate_tfrecord.py resolves image names against --image_dir, whose
#    default (empty) means the current directory.
cd train
python ../generate_tfrecord.py --csv_input=../data/train_labels.csv --output_path=../data/train.record
cd ../test
python ../generate_tfrecord.py --csv_input=../data/test_labels.csv --output_path=../data/test.record
cd ..

# 4) Train.
python $TF_RESEARCH/object_detection/model_main.py --logtostderr --pipeline_config_path=training/ssd_mobilenet_v1_coco.config --model_dir=training

# NOTE(review): this logdir differs from --model_dir=training above;
# presumably it points at an earlier run — confirm before relying on it.
tensorboard --logdir=training-ssd-mobilenet_300_300

# 5) Export a frozen graph. NOTE(review): uses the training-3 run and a
# relative script path, unlike model_main.py above — confirm the intended run.
python export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path training-3/faster_rcnn_inception_resnet_v2_atrous_coco.config \
    --trained_checkpoint_prefix training-3/model.ckpt-200000 \
    --output_directory training-3-3output

# Hardware monitoring while training: host sensor readings and GPU temperature.
sensors
nvidia-smi --query-gpu=temperature.gpu --format=csv,noheader
----- |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from random import shuffle | |
from math import floor | |
import shutil | |
def get_training_and_testing_sets(file_list, train_fraction=0.7):
    """Split ``file_list`` into a leading training part and a trailing test part.

    Args:
        file_list: sequence to split (shuffle it first for a random split).
        train_fraction: share of items, between 0 and 1, that goes to the
            training part; the count is rounded down. Defaults to 0.7.

    Returns:
        (training, testing) slices of ``file_list``.

    Raises:
        ValueError: if ``train_fraction`` is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError('train_fraction must be in [0, 1], got {!r}'.format(train_fraction))
    # math.floor already returns an int on Python 3; int() keeps Python 2 safe.
    split_index = int(floor(len(file_list) * train_fraction))
    return file_list[:split_index], file_list[split_index:]
def get_file_list_from_dir(datadir):
    """Return the names of the ``.jpg`` files in ``datadir`` without that suffix.

    Only names ending in lowercase ``.jpg`` are considered (``.JPG`` or
    ``.jpeg`` files are ignored). Order follows ``os.listdir``.
    """
    suffix = '.jpg'
    names = os.listdir(os.path.abspath(datadir))
    # Strip only the trailing extension; str.replace would also remove any
    # earlier '.jpg' inside the name (e.g. 'a.jpg.jpg' -> 'a').
    return [name[:-len(suffix)] for name in names if name.endswith(suffix)]
regions = ['Chiba', 'Sumida', 'Adachi', 'Ichihara', 'Nagakute', 'Muroran', 'Numazu']


def _copy_annotated_pairs(region, names, dest_dir):
    """Copy ``<name>.jpg`` and ``<name>.xml`` of ``region`` into ``dest_dir``.

    Images that have no matching ``Annotations/<name>.xml`` are skipped.
    """
    for name in names:
        xml_src = os.path.join(region, 'Annotations', name + '.xml')
        if not os.path.exists(xml_src):
            continue
        shutil.copy(os.path.join(region, 'JPEGImages', name + '.jpg'),
                    os.path.join(dest_dir, name + '.jpg'))
        shutil.copy(xml_src, os.path.join(dest_dir, name + '.xml'))


# The output directories are shared by all regions, so create them once.
for _out_dir in ('train', 'test'):
    os.makedirs(_out_dir, exist_ok=True)

for region in regions:
    file_names = get_file_list_from_dir(os.path.join(region, 'JPEGImages'))
    # NOTE: shuffle() is unseeded, so the train/test split differs per run.
    shuffle(file_names)
    training, test = get_training_and_testing_sets(file_names)
    _copy_annotated_pairs(region, training, 'train')
    _copy_annotated_pairs(region, test, 'test')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import glob | |
import pandas as pd | |
import xml.etree.ElementTree as ET | |
def xml_to_csv(path):
    """Collect every bounding box from the ``*.xml`` annotation files in ``path``.

    Args:
        path: directory containing the annotation XML files.

    Returns:
        pandas.DataFrame with one row per ``<object>`` element and columns
        ``filename, width, height, class, xmin, ymin, xmax, ymax``. Numeric
        fields are parsed as float and truncated to int.
    """
    xml_list = []
    # sorted() makes the row order (and hence the CSV) deterministic; raw
    # glob order depends on the filesystem.
    for xml_file in sorted(glob.glob(os.path.join(path, '*.xml'))):
        root = ET.parse(xml_file).getroot()
        members = root.findall('object')
        if not members:
            # No boxes: nothing to emit, and the header fields are not needed.
            continue
        filename = root.find('filename').text
        # Look the size fields up by tag rather than by child position so a
        # <size> element with differently ordered children is read correctly.
        width = int(float(root.find('size/width').text))
        height = int(float(root.find('size/height').text))
        for member in members:
            bndbox = member.find('bndbox')
            xml_list.append((
                filename,
                width,
                height,
                member.find('name').text,
                int(float(bndbox.find('xmin').text)),
                int(float(bndbox.find('ymin').text)),
                int(float(bndbox.find('xmax').text)),
                int(float(bndbox.find('ymax').text)),
            ))
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(xml_list, columns=column_name)
def main():
    """Write ``data/<split>_labels.csv`` for the ``test`` and ``train`` folders.

    Both folders are looked up under the current working directory, and the
    ``data`` output directory is created there if it does not exist.
    """
    # exist_ok avoids the separate exists()/makedirs() check-then-create step.
    os.makedirs('data', exist_ok=True)
    for p in ['test', 'train']:
        image_path = os.path.join(os.getcwd(), p)
        xml_df = xml_to_csv(image_path)
        # index=False: do not write the DataFrame's positional index as a column.
        xml_df.to_csv('data/{}_labels.csv'.format(p), index=False)
    # Reported once, after both CSV files have been written.
    print('Successfully converted xml to csv.')


main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment