Script to streamline training with the TensorFlow Object Detection API: it splits a dataset, generates the label map and TFRecords, patches the pipeline config, and launches training.
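Typical invocation (the flags come from the script's argument parser below; the script file name and data paths are placeholders):

    python train_od.py --input-images ./data/images --input-annotations ./data/annotations --num-steps 40000

`--pipeline-file` and `--config-weights-relation-file` are optional; pass `--continue` to resume a previous run.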
import os, sys
import shutil
from argparse import ArgumentParser
from random import randrange
from PIL import Image
import glob
import re
import xml.etree.ElementTree as ET
import pandas as pd
import tensorflow as tf
from google.protobuf import text_format
import time
import subprocess

try:
    import object_detection
    from object_detection.protos import pipeline_pb2
    import nets
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    os.system('cd model/ && protoc object_detection/protos/*.proto --python_out=.')
    os.system('cd model/ && pip install .')
    os.system('cd model/slim && python setup.py install')
    from object_detection.protos import pipeline_pb2  # pipeline_pb2 lives under object_detection.protos

# Use os.pathsep so the PYTHONPATH separator is correct on any platform
os.environ["PYTHONPATH"] = os.pathsep.join([os.path.abspath('./model'), os.path.abspath('./model/slim')])
__author__ = ""
__date__ = ""
__version__ = ""
__description__ = ""
if __name__ == '__main__':
    argParser = ArgumentParser(description=__description__,
                               epilog='Developed by ' + __author__ + ' in ' + __date__)
    argParser.add_argument('--input-images', dest="input_images", type=str, required=True)
    argParser.add_argument('--input-annotations', dest="input_annotations", type=str, required=True)
    argParser.add_argument('--num-steps', dest="num_steps", type=int, default=40000)
    argParser.add_argument('--pipeline-file', dest="pipeline_file", type=str, default='model/object_detection/samples/configs/faster_rcnn_inception_v2_pets.config')
    # Default to None so the single-pipeline fallback branch below is reachable;
    # the original default pointed at a pipeline config, which is not a CSV.
    argParser.add_argument('--config-weights-relation-file', dest="config_weights_relation_file", type=str, default=None)
    argParser.add_argument('--continue', dest="continue_training", action='store_true')
    args = argParser.parse_args()
    ANNOTATIONS_PATH = os.path.abspath(args.input_annotations)
    IMAGES_PATH = os.path.abspath(args.input_images)
    OD_PATH = os.path.abspath('model/object_detection/')
    IMAGES_TEST_PATH = os.path.abspath(OD_PATH + '/images/test/')
    IMAGES_TRAIN_PATH = os.path.abspath(OD_PATH + '/images/train/')
    CONF_FILES_PATH = './model/object_detection/samples/configs'
    WEIGHTS_PATH = './model/object_detection/weights/'
    if args.continue_training:
        os.system(f'cd {OD_PATH} && \
                  python train.py \
                  --logtostderr \
                  --train_dir=training/ \
                  --pipeline_config_path=training/{os.path.basename(args.pipeline_file)} --worker_replicas=2 --ps_tasks=1')
        sys.exit(0)
    def clean_folder(folder):
        if os.path.isdir(folder):
            shutil.rmtree(folder)
        os.mkdir(folder)

    ## Recreate the images, training and inference_graph folders from scratch.
    ## Wiping the parent images/ folder also removes train/ and test/, so those
    ## two are simply recreated afterwards.
    clean_folder(f'{OD_PATH}/images')
    clean_folder(OD_PATH + '/inference_graph')
    clean_folder(OD_PATH + '/training')
    os.mkdir(IMAGES_TEST_PATH)
    os.mkdir(IMAGES_TRAIN_PATH)

    if not os.path.isdir(WEIGHTS_PATH):
        os.mkdir(WEIGHTS_PATH)
    ## Split ~80% train / ~20% test and convert images to jpg
    for annot_xml in glob.glob(ANNOTATIONS_PATH + '/*.xml'):
        xml_file = annot_xml
        image_name = annot_xml.replace('\\', '/').split('/')[-1].split('.')[0]
        xml_name = annot_xml.replace('\\', '/').split('/')[-1]
        try:
            im = Image.open(f'{IMAGES_PATH}/{image_name}.bmp')
        except FileNotFoundError:
            try:
                im = Image.open(f'{IMAGES_PATH}/{image_name}.jpg')
            except FileNotFoundError:
                print('FileNotFoundError:', image_name)
                continue
        # randrange(10) yields 0-9, so values 8 and 9 (~20%) land in the test set
        if randrange(10) > 7:
            shutil.copy(xml_file, IMAGES_TEST_PATH)
            im.save(f'{IMAGES_TEST_PATH}/{image_name}.jpg')
        else:
            shutil.copy(xml_file, IMAGES_TRAIN_PATH)
            im.save(f'{IMAGES_TRAIN_PATH}/{image_name}.jpg')
    ## Fix the filename/path fields inside each annotation so they point at the jpg copy
    def fix_filepath(annot_xml):
        root = ET.parse(annot_xml)
        image_name = annot_xml.replace('\\', '/').split('/')[-1].split('.')[0] + '.jpg'
        # os.path.dirname handles both '/' and '\\' separators, unlike a manual split on '/'
        image_path = os.path.join(os.path.dirname(os.path.abspath(annot_xml)), image_name)
        root.find('filename').text = image_name
        root.find('path').text = image_path
        root.write(annot_xml)

    for annot_xml in glob.glob(IMAGES_TEST_PATH + '/*.xml'):
        fix_filepath(annot_xml)
    for annot_xml in glob.glob(IMAGES_TRAIN_PATH + '/*.xml'):
        fix_filepath(annot_xml)
    os.system(f'cd {OD_PATH} && python ./xml_to_csv.py')
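    ## The loop below writes labelmap.pbtxt in the label map format the Object
    ## Detection API expects, one item per class found in train_labels.csv, e.g.
    ##   item {
    ##     id: 1
    ##     name: 'cat'
    ##   }
    ## ('cat' is only an illustrative class name; ids start at 1 because id 0 is
    ## reserved for the background class.)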
    ## Write the labels pbtxt file
    file = pd.read_csv(f'{OD_PATH}/images/train_labels.csv')
    categories = file['class'].unique()
    end = '\n'
    s = ' '
    class_map = {}
    for ID, name in enumerate(categories):
        out = ''
        out += 'item' + s + '{' + end
        out += s*2 + 'id:' + ' ' + str(ID+1) + end
        out += s*2 + 'name:' + ' ' + '\'' + name + '\'' + end
        out += '}' + end*2
        with open(f'{OD_PATH}/training/labelmap.pbtxt', 'a') as f:
            f.write(out)
        class_map[name] = ID+1
os.system(f"cd {OD_PATH} && \ | |
python generate_tfrecord.py \ | |
--csv_input=images/train_labels.csv \ | |
--image_dir=images/train \ | |
--output_path=train.record \ | |
--classes {' '.join(file['class'].unique())}") | |
os.system(f"cd {OD_PATH} && \ | |
python generate_tfrecord.py \ | |
--csv_input=images/test_labels.csv \ | |
--image_dir=images/test \ | |
--output_path=test.record \ | |
--classes {' '.join(file['class'].unique())}") | |
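    ## The optional relation file is a headerless CSV, read below with pandas,
    ## where each row pairs a pipeline config with the download URL of its
    ## pretrained weights:
    ##   column 0: config file name inside samples/configs
    ##   column 1: URL of the matching .tar.gz checkpoint from the model zoo
    ## An illustrative row (URL not verified here) might look like:
    ##   faster_rcnn_inception_v2_pets.config,http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz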
    if args.config_weights_relation_file:
        config_relations_df = pd.read_csv(args.config_weights_relation_file, header=None)

        all_logs_dir = os.path.abspath('./train-logs')
        if not os.path.isdir(all_logs_dir):
            os.mkdir(all_logs_dir)
        current_logs_dir_name = time.strftime("%Y%m%d-%H%M%S")
        current_logs_dir = os.path.abspath(f'{all_logs_dir}/{current_logs_dir_name}')
        os.mkdir(current_logs_dir)
        for idx, row in config_relations_df.iterrows():
            pipeline_file = CONF_FILES_PATH + '/' + row[0]
            pipeline_file_name = pipeline_file.split('/')[-1].split('.')[0]
            url_download_weight = row[1]
            # '.'.join(...[:-2]) strips the trailing '.tar.gz', which matches the
            # folder name inside the model-zoo tarball
            weight_file_name = '.'.join(row[1].split('/')[-1].split('.')[:-2])
            weight_file = os.path.join(WEIGHTS_PATH, row[1].split('/')[-1])
            weight_path = os.path.join(WEIGHTS_PATH, weight_file_name)
            if not os.path.isdir(weight_path):
                if not os.path.isfile(weight_file):
                    res = subprocess.run(f'wget -O {weight_file} {url_download_weight}', shell=True)
                res = subprocess.run(f'tar -xzf {weight_file} -C {WEIGHTS_PATH} && rm {weight_file}', shell=True)

            train_path = os.path.abspath(f'{OD_PATH}/training/{weight_file_name}-{pipeline_file_name}')
            if os.path.isdir(train_path):
                shutil.rmtree(train_path)
            os.mkdir(train_path)

            pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
            with tf.gfile.GFile(pipeline_file, "r") as f:
                proto_str = f.read()
                text_format.Merge(proto_str, pipeline_config)
            if 'ssd' in pipeline_file:
                pipeline_config.model.ssd.num_classes = len(file['class'].unique())
            else:
                pipeline_config.model.faster_rcnn.num_classes = len(file['class'].unique())
            pipeline_config.train_config.fine_tune_checkpoint = os.path.abspath(f'{weight_path}/model.ckpt')
            pipeline_config.train_input_reader.tf_record_input_reader.input_path[:] = [os.path.abspath(f'{OD_PATH}/train.record')]
            pipeline_config.train_input_reader.label_map_path = os.path.abspath(f'{OD_PATH}/training/labelmap.pbtxt')
            pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[:] = [os.path.abspath(f'{OD_PATH}/test.record')]
            pipeline_config.eval_input_reader[0].label_map_path = os.path.abspath(f'{OD_PATH}/training/labelmap.pbtxt')
            pipeline_config.train_config.num_steps = args.num_steps
            config_text = text_format.MessageToString(pipeline_config)

            pipeline_file_name = pipeline_file.replace('\\', '/').split('/')[-1]
            pipeline_file_path = f'{train_path}/{pipeline_file_name}'

            def write_pipeline_conf(text):
                with tf.gfile.Open(pipeline_file_path, "wb") as f:
                    f.write(text)
            write_pipeline_conf(config_text)
            current_train_dir = os.path.abspath(f'{current_logs_dir}/{weight_file_name}-{pipeline_file_name}')
            os.mkdir(current_train_dir)
            train_logs_file = os.path.abspath(f"{current_train_dir}/logs.txt")
            print(f"Training model {weight_file_name} using configs {pipeline_file_name}")

            def start_training(f):
                return subprocess.run(["python", f"{OD_PATH}/train.py",
                                       "--train_dir", train_path,
                                       "--pipeline_config_path", pipeline_file_path],
                                      stdout=f, stderr=f)

            with open(train_logs_file, "wb") as f:
                proc_result = start_training(f)
                # On failure, toggle checkpoint-related options and retry. The
                # modified config must be re-serialized and rewritten to disk,
                # otherwise train.py keeps reading the old file and each retry
                # is a no-op.
                if proc_result.returncode != 0:
                    pipeline_config.train_config.from_detection_checkpoint = False
                    write_pipeline_conf(text_format.MessageToString(pipeline_config))
                    proc_result = start_training(f)
                if proc_result.returncode != 0:
                    pipeline_config.train_config.from_detection_checkpoint = True
                    write_pipeline_conf(text_format.MessageToString(pipeline_config))
                    proc_result = start_training(f)
                if proc_result.returncode != 0:
                    pipeline_config.train_config.fine_tune_checkpoint_type = "detection"
                    write_pipeline_conf(text_format.MessageToString(pipeline_config))
                    proc_result = start_training(f)

            with open(f"{current_logs_dir}/general-train-results.txt", "a") as f:
                if proc_result.returncode != 0:
                    f.write(f"problem during train of model {weight_file_name} using configs {pipeline_file_name}\n")
                else:
                    f.write(f"successfully trained model {weight_file_name} using configs {pipeline_file_name}\n")
    else:
        pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
        # Use the pipeline file passed on the command line; `pipeline_file` is
        # only defined inside the relation-file branch above
        with tf.gfile.GFile(args.pipeline_file, "r") as f:
            proto_str = f.read()
            text_format.Merge(proto_str, pipeline_config)
        if 'ssd' in args.pipeline_file:
            pipeline_config.model.ssd.num_classes = len(file['class'].unique())
        else:
            pipeline_config.model.faster_rcnn.num_classes = len(file['class'].unique())
        pipeline_config.train_config.fine_tune_checkpoint = os.path.abspath(f'{OD_PATH}/faster_rcnn_inception_v2_coco_2018_01_28/model.ckpt')
        pipeline_config.train_input_reader.tf_record_input_reader.input_path[:] = [os.path.abspath(f'{OD_PATH}/train.record')]
        pipeline_config.train_input_reader.label_map_path = os.path.abspath(f'{OD_PATH}/training/labelmap.pbtxt')
        pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[:] = [os.path.abspath(f'{OD_PATH}/test.record')]
        pipeline_config.eval_input_reader[0].label_map_path = os.path.abspath(f'{OD_PATH}/training/labelmap.pbtxt')
        pipeline_config.train_config.num_steps = args.num_steps
        config_text = text_format.MessageToString(pipeline_config)
        pipeline_file_name = args.pipeline_file.replace('\\', '/').split('/')[-1]
        with tf.gfile.Open(f"{OD_PATH}/training/{pipeline_file_name}", "wb") as f:
            f.write(config_text)
        # Train against the config that was just written, not a hardcoded name
        os.system(f'cd {OD_PATH} && \
                  python train.py \
                  --logtostderr \
                  --train_dir=training/ \
                  --pipeline_config_path=training/{pipeline_file_name} --worker_replicas=2 --ps_tasks=1')
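    ## Not handled by this script: after training, the Object Detection API's
    ## export script can freeze the graph into the inference_graph folder that
    ## was cleaned above. A typical call (checkpoint step and paths below are
    ## placeholders) looks like:
    ##   python export_inference_graph.py --input_type image_tensor \
    ##     --pipeline_config_path training/<pipeline>.config \
    ##     --trained_checkpoint_prefix training/model.ckpt-<step> \
    ##     --output_directory inference_graph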