Created October 4, 2018 16:57
-
-
Save mikaelhg/9e7d303f3ff68c23ce961e1875528a1d to your computer and use it in GitHub Desktop.
Minimal TensorFlow object detection example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings

# h5py is not used directly in this module. It is imported here with
# FutureWarning filtered out; because Python caches a module after its first
# import, any later `import h5py` (NOTE(review): presumably the one triggered
# by TensorFlow below) will not re-emit those warnings -- confirm intent.
with warnings.catch_warnings():
    warnings.filterwarnings('ignore', category=FutureWarning)
    import h5py

import os

# Set before `import tensorflow` on purpose. NOTE(review): presumably because
# the native TF runtime reads TF_CPP_MIN_LOG_LEVEL when it loads, with '1'
# suppressing INFO-level C++ log output -- confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import tensorflow as tf

import queue
from threading import Thread, Event

# Graph tensors fetched on every sess.run(); each is looked up as '<name>:0'.
_OUTPUT_TENSOR_NAMES = ['detection_boxes', 'detection_scores', 'num_detections', 'detection_classes']
# Name of the graph's input tensor that images are fed into.
_INPUT_TENSOR_NAME = 'image_tensor'
class DetectionThread(Thread):
    """Background thread that runs object detection on images taken from a queue.

    The thread loads a frozen TensorFlow graph from ``model`` and opens a session.
    It then repeatedly takes one image from ``image_q`` and runs the detector on it
    as a batch of one. Each result dict is put on ``detection_q``, and every
    detection scoring above ``threshold`` is printed. Call :meth:`interrupt` to
    stop the thread.
    """

    def __init__(self, model: str, cpu: bool, image_q: queue.Queue, detection_q: queue.Queue,
                 threshold: float = 0.7):
        """
        :param model: path to a serialized (frozen) GraphDef file.
        :param cpu: if true, the session sees no GPUs; otherwise GPU memory is
            capped at 20% of the device and allowed to grow.
        :param image_q: queue the thread reads single images from.
        :param detection_q: queue that receives the result dict of each sess.run().
        :param threshold: minimum score for a detection to be printed.
        """
        super(DetectionThread, self).__init__()
        self.model = model
        self.cpu = cpu
        self.image_q = image_q
        self.detection_q = detection_q
        self.threshold = threshold
        # Per-instance event. As a class attribute it was shared by every
        # DetectionThread, so interrupting one thread stopped all of them, and
        # any thread created afterwards exited immediately.
        self.stop_thread = Event()

    def run(self):
        graph = self._load_graph()
        print('M Starting TF session...')
        with tf.Session(graph=graph, config=self._session_config()) as sess:
            output_tensors = {k: graph.get_tensor_by_name(k + ':0') for k in _OUTPUT_TENSOR_NAMES}
            image_tensor = graph.get_tensor_by_name(_INPUT_TENSOR_NAME + ':0')
            total_detections = 0
            while not self.stop_thread.is_set():
                try:
                    # Time out periodically so a stop request is noticed while idle.
                    image = self.image_q.get(block=True, timeout=2.0)
                except queue.Empty:
                    continue
                # Count only images actually processed, not idle timeouts.
                total_detections += 1
                if total_detections % 100 == 0:
                    print(f'D total_detections={total_detections}')
                # The graph takes a batch; wrap the single image as a batch of one.
                result = sess.run(output_tensors, feed_dict={image_tensor: [image]})
                self.detection_q.put(result)
                self._print_detections(result)
            print('D Stopping.')
        print('D Closed detector.')

    def _load_graph(self):
        """Read the GraphDef at ``self.model`` and import it into a new tf.Graph."""
        graph = tf.Graph()
        print('M Loading graph...')
        with open(self.model, 'rb') as f:
            graph_def = tf.GraphDef.FromString(f.read())
        print('M Loaded graph.')
        with graph.as_default():
            tf.import_graph_def(graph_def, name='')
        return graph

    def _session_config(self):
        """Build the session config: hide GPUs on CPU runs, else cap and grow GPU memory."""
        config = tf.ConfigProto()
        if self.cpu:
            config.device_count['GPU'] = 0
        else:
            config.gpu_options.per_process_gpu_memory_fraction = 0.2
            config.gpu_options.allow_growth = True
        return config

    def _print_detections(self, result):
        """Print image index, detection index, class, score and box for confident detections."""
        for i, count in enumerate(result['num_detections']):
            classes = result['detection_classes'][i].astype(int)
            boxes = result['detection_boxes'][i]
            scores = result['detection_scores'][i]
            for j in range(int(count)):
                if scores[j] > self.threshold:
                    print('%3d %3d %3d %0.2f %s' % (i, j, classes[j], scores[j], boxes[j]))

    def interrupt(self):
        """Request a stop.

        The thread exits after the image in progress, or within about 2 s
        (the queue timeout) when idle.
        """
        print('D Interrupting detector.')
        self.stop_thread.set()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, List, Tuple | |
import PIL | |
from PIL import Image | |
import argparse, itertools | |
import tensorflow as tf | |
import numpy as np | |
# Where the pretrained frozen graphs come from (TensorFlow detection model zoo):
#   https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
# Label map for the COCO-trained models (presumably the id -> name mapping for
# the printed detection_classes values):
#   https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_label_map.pbtxt
MODEL_DIR = '/media/mikael/Data/ml-data/models'

# Other frozen graphs tried with this script; each lives at
# MODEL_DIR/<name>/frozen_inference_graph.pb unless an absolute path is given:
#   faster_rcnn_resnet101_kitti_2018_01_28
#   faster_rcnn_resnet50_lowproposals_coco_2018_01_28
#   ssd_mobilenet_v2_coco_2018_03_29
#   faster_rcnn_inception_v2_coco_2018_01_28
#   faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28
#   faster_rcnn_inception_resnet_v2_atrous_lowproposals_oid_2018_01_28
#   /home/mikael/ml/train/pets/exported_graphs/frozen_inference_graph.pb
MODEL_FILE = f'{MODEL_DIR}/ssdlite_mobilenet_v2_coco_2018_05_09/frozen_inference_graph.pb'

# Graph tensors fetched per batch, each looked up as '<name>:0'.
OUTPUT_TENSOR_NAMES = [
    'detection_boxes',
    'detection_scores',
    'num_detections',
    'detection_classes',
]
# Graph input that batches of images are fed into.
INPUT_TENSOR_NAME = 'image_tensor'
def is_not_none(x):
    """Predicate for ``filter``: true for every value except ``None`` itself."""
    return not (x is None)
def grouper(iterable, n, fillvalue=None):
    """Split *iterable* into tuples of length *n*, padding the last with *fillvalue*.

    Example: ``grouper('ABCDEFG', 3, 'x')`` yields ABC, DEF, Gxx.
    """
    shared = iter(iterable)
    # Handing zip_longest the *same* iterator n times makes each output tuple
    # consume n consecutive items.
    return itertools.zip_longest(*(shared for _ in range(n)), fillvalue=fillvalue)
def generate_arrays(files):
    """Lazily yield each image file in *files* as a numpy array of RGB pixels.

    Every image is converted to 3-channel RGB, so grayscale, palette and RGBA
    files come out with the same channel layout as ordinary colour photos. The
    detection graphs' ``image_tensor`` input takes 3-channel images. Each file
    is closed before its array is yielded, so no handle stays open while the
    caller runs inference.
    """
    for filename in files:
        with Image.open(filename) as image:
            pixels = np.asarray(image.convert('RGB'))
        yield pixels
def _load_graph(model_file):
    """Read a frozen GraphDef from *model_file* and import it into a fresh tf.Graph."""
    graph = tf.Graph()
    with open(model_file, 'rb') as f:
        graph_def = tf.GraphDef.FromString(f.read())
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph


def _session_config(cpu):
    """Build the session config: hide GPUs when *cpu*, otherwise cap and grow GPU memory."""
    config = tf.ConfigProto()
    if cpu:
        config.device_count['GPU'] = 0
    else:
        config.gpu_options.per_process_gpu_memory_fraction = 0.2
        config.gpu_options.allow_growth = True
    return config


def _print_detections(result, first_index, threshold):
    """Print one line per detection in *result* whose score passes *threshold*.

    *first_index* is the position of the batch's first image in the whole input,
    so the printed image index identifies the file across batches.
    """
    for i, count in enumerate(result['num_detections']):
        classes = result['detection_classes'][i].astype(int)
        boxes = result['detection_boxes'][i]
        scores = result['detection_scores'][i]
        for j in range(int(count)):
            # A falsy threshold (e.g. 0) means "print every detection".
            if not threshold or scores[j] > threshold:
                print('%3d %3d %3d %0.2f %s' % (first_index + i, j, classes[j], scores[j], boxes[j]))


def main(args):
    """Run the detector over ``args.image_files`` in batches and print detections.

    Each output line holds: image index (position in ``args.image_files``),
    detection index, class id, score and box.
    """
    np.set_printoptions(precision=2)
    # Namespaces built elsewhere may lack `model`; fall back to the built-in path.
    model_file = getattr(args, 'model', None) or MODEL_FILE
    graph = _load_graph(model_file)
    output_tensors = {k: graph.get_tensor_by_name(k + ':0') for k in OUTPUT_TENSOR_NAMES}
    image_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME + ':0')
    with tf.Session(graph=graph, config=_session_config(args.cpu)) as sess:
        all_images = generate_arrays(args.image_files)
        # All images in one batch must have the same size, or numpy cannot
        # stack them into a single batch array.
        for batch_number, padded in enumerate(grouper(all_images, args.batch_size, None)):
            batch = list(filter(is_not_none, padded))  # drop the final batch's None padding
            result = sess.run(output_tensors, feed_dict={image_tensor: batch})
            # Every batch but the last is full, so this is the first image's global index.
            _print_detections(result, batch_number * args.batch_size, args.threshold)


def _positive_int(text):
    """argparse ``type=`` helper: parse *text* as an int that is at least 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f'must be at least 1, got {value}')
    return value


def parse_arguments(argv=None):
    """Parse command-line options from *argv* (defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--cpu', help='Run on CPU', default=False, action='store_true')
    parser.add_argument('--model', default=None,
                        help='Frozen inference graph (.pb) to load; the built-in MODEL_FILE is used when omitted.')
    parser.add_argument('-t', '--threshold', type=float, default=0.5,
                        help='The probability threshold for displaying a prediction.')
    # A batch size below 1 would make grouper() yield nothing, silently skipping all images.
    parser.add_argument('-n', '--batch-size', type=_positive_int, default=1,
                        help='How many images to process in one batch.')
    parser.add_argument('image_files', type=str, nargs='+', help='Image files.')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the command line, then run detection.
    cli_args = parse_arguments()
    main(cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.