Created
March 5, 2022 18:36
-
-
Save Alyetama/068054632e6ceacbf066664e2c18e920 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
r""" | |
# SOURCE: https://github.com/microsoft/CameraTraps/blob/main/detection/run_tf_detector.py | |
# EDITED TO WORK WITH SHARED GPUs | |
#----------------------------------------------------------------------------- | |
Module to run a TensorFlow animal detection model on images. | |
The class TFDetector contains functions to load a TensorFlow detection model and | |
run inference. The main function in this script also renders the predicted | |
bounding boxes on images and saves the resulting images (with bounding boxes). | |
This script is not a good way to process lots of images (tens of thousands, | |
say). It does not facilitate checkpointing the results so if it crashes you | |
would have to start from scratch. If you want to run a detector (e.g., ours) | |
on lots of images, you should check out: | |
1) run_tf_detector_batch.py (for local execution) | |
2) https://github.com/microsoft/CameraTraps/tree/master/api/batch_processing | |
(for running large jobs on Azure ML) | |
To run this script, we recommend you set up a conda virtual environment | |
following instructions in the Installation section on the main README, using | |
`environment-detector.yml` as the environment file where asked. | |
This is a good way to test our detector on a handful of images and get | |
super-satisfying, graphical results. It's also a good way to see how fast a | |
detector model will run on a particular machine. | |
If you would like to *not* use the GPU on the machine, set the environment | |
variable CUDA_VISIBLE_DEVICES to "-1". | |
If no output directory is specified, writes detections for c:\foo\bar.jpg to | |
c:\foo\bar_detections.jpg. | |
This script will only consider detections with > 0.1 confidence at all times. | |
The `threshold` you provide is only for rendering the results. If you need to | |
see lower-confidence detections, you can change | |
DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD. | |
Reference: | |
https://github.com/tensorflow/models/blob/master/research/object_detection/inference/detection_inference.py | |
""" | |
#%% Constants, imports, environment | |
import argparse | |
import glob | |
import os | |
import statistics | |
import sys | |
import time | |
import warnings | |
import humanfriendly | |
import numpy as np | |
from tqdm import tqdm | |
from ct_utils import truncate_float | |
import visualization.visualization_utils as viz_utils | |
# ignoring all "PIL cannot read EXIF metainfo for the images" warnings | |
warnings.filterwarnings('ignore', '(Possibly )?corrupt EXIF data', UserWarning) | |
# Metadata Warning, tag 256 had too many entries: 42, expected 1 | |
warnings.filterwarnings('ignore', 'Metadata warning', UserWarning) | |
# Numpy FutureWarnings from tensorflow import | |
warnings.filterwarnings('ignore', category=FutureWarning) | |
# Useful hack to force CPU inference | |
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1' | |
import tensorflow.compat.v1 as tf | |
from tensorflow.compat.v1 import ConfigProto | |
print('TensorFlow version:', tf.__version__) | |
print('Is GPU available? tf.test.is_gpu_available:', tf.test.is_gpu_available()) | |
#%% Classes | |
class ImagePathUtils: | |
"""A collection of utility functions supporting this stand-alone script""" | |
# Stick this into filenames before the extension for the rendered result | |
DETECTION_FILENAME_INSERT = '_detections' | |
image_extensions = ['.jpg', '.jpeg', '.gif', '.png'] | |
@staticmethod | |
def is_image_file(s): | |
""" | |
Check a file's extension against a hard-coded set of image file extensions | |
""" | |
ext = os.path.splitext(s)[1] | |
return ext.lower() in ImagePathUtils.image_extensions | |
@staticmethod | |
def find_image_files(strings): | |
""" | |
Given a list of strings that are potentially image file names, look for strings | |
that actually look like image file names (based on extension). | |
""" | |
return [s for s in strings if ImagePathUtils.is_image_file(s)] | |
@staticmethod | |
def find_images(dir_name, recursive=False): | |
""" | |
Find all files in a directory that look like image file names | |
""" | |
if recursive: | |
strings = glob.glob(os.path.join(dir_name, '**', '*.*'), recursive=True) | |
else: | |
strings = glob.glob(os.path.join(dir_name, '*.*')) | |
image_strings = ImagePathUtils.find_image_files(strings) | |
return image_strings | |
class TFDetector: | |
""" | |
A detector model loaded at the time of initialization. It is intended to be used with | |
the MegaDetector (TF). The inference batch size is set to 1; code needs to be modified | |
to support larger batch sizes, including resizing appropriately. | |
""" | |
# Number of decimal places to round to for confidence and bbox coordinates | |
CONF_DIGITS = 3 | |
COORD_DIGITS = 4 | |
# MegaDetector was trained with batch size of 1, and the resizing function is a part | |
# of the inference graph | |
BATCH_SIZE = 1 | |
# An enumeration of failure reasons | |
FAILURE_TF_INFER = 'Failure TF inference' | |
FAILURE_IMAGE_OPEN = 'Failure image access' | |
DEFAULT_RENDERING_CONFIDENCE_THRESHOLD = 0.85 # to render bounding boxes | |
DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD = 0.1 # to include in the output json file | |
DEFAULT_DETECTOR_LABEL_MAP = { | |
'1': 'animal', | |
'2': 'person', | |
'3': 'vehicle' # available in megadetector v4+ | |
} | |
NUM_DETECTOR_CATEGORIES = 4 # animal, person, group, vehicle - for color assignment | |
def __init__(self, model_path): | |
"""Loads model from model_path and starts a tf.Session with this graph. Obtains | |
input and output tensor handles.""" | |
detection_graph = TFDetector.__load_model(model_path) | |
config = ConfigProto() | |
config.gpu_options.allow_growth = True | |
self.tf_session = tf.Session(config=config, graph=detection_graph) | |
self.image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') | |
self.box_tensor = detection_graph.get_tensor_by_name('detection_boxes:0') | |
self.score_tensor = detection_graph.get_tensor_by_name('detection_scores:0') | |
self.class_tensor = detection_graph.get_tensor_by_name('detection_classes:0') | |
@staticmethod | |
def round_and_make_float(d, precision=4): | |
return truncate_float(float(d), precision=precision) | |
@staticmethod | |
def __convert_coords(tf_coords): | |
"""Converts coordinates from the model's output format [y1, x1, y2, x2] to the | |
format used by our API and MegaDB: [x1, y1, width, height]. All coordinates | |
(including model outputs) are normalized in the range [0, 1]. | |
Args: | |
tf_coords: np.array of predicted bounding box coordinates from the TF detector, | |
has format [y1, x1, y2, x2] | |
Returns: list of Python float, predicted bounding box coordinates [x1, y1, width, height] | |
""" | |
# change from [y1, x1, y2, x2] to [x1, y1, width, height] | |
width = tf_coords[3] - tf_coords[1] | |
height = tf_coords[2] - tf_coords[0] | |
new = [tf_coords[1], tf_coords[0], width, height] # must be a list instead of np.array | |
# convert numpy floats to Python floats | |
for i, d in enumerate(new): | |
new[i] = TFDetector.round_and_make_float(d, precision=TFDetector.COORD_DIGITS) | |
return new | |
@staticmethod | |
def convert_to_tf_coords(array): | |
"""From [x1, y1, width, height] to [y1, x1, y2, x2], where x1 is x_min, x2 is x_max | |
This is an extraneous step as the model outputs [y1, x1, y2, x2] but were converted to the API | |
output format - only to keep the interface of the sync API. | |
""" | |
x1 = array[0] | |
y1 = array[1] | |
width = array[2] | |
height = array[3] | |
x2 = x1 + width | |
y2 = y1 + height | |
return [y1, x1, y2, x2] | |
@staticmethod | |
def __load_model(model_path): | |
"""Loads a detection model (i.e., create a graph) from a .pb file. | |
Args: | |
model_path: .pb file of the model. | |
Returns: the loaded graph. | |
""" | |
print('TFDetector: Loading graph...') | |
detection_graph = tf.Graph() | |
with detection_graph.as_default(): | |
od_graph_def = tf.GraphDef() | |
with tf.gfile.GFile(model_path, 'rb') as fid: | |
serialized_graph = fid.read() | |
od_graph_def.ParseFromString(serialized_graph) | |
tf.import_graph_def(od_graph_def, name='') | |
print('TFDetector: Detection graph loaded.') | |
return detection_graph | |
def _generate_detections_one_image(self, image): | |
np_im = np.asarray(image, np.uint8) | |
im_w_batch_dim = np.expand_dims(np_im, axis=0) | |
# need to change the above line to the following if supporting a batch size > 1 and resizing to the same size | |
# np_images = [np.asarray(image, np.uint8) for image in images] | |
# images_stacked = np.stack(np_images, axis=0) if len(images) > 1 else np.expand_dims(np_images[0], axis=0) | |
# performs inference | |
(box_tensor_out, score_tensor_out, class_tensor_out) = self.tf_session.run( | |
[self.box_tensor, self.score_tensor, self.class_tensor], | |
feed_dict={self.image_tensor: im_w_batch_dim}) | |
return box_tensor_out, score_tensor_out, class_tensor_out | |
def generate_detections_one_image(self, image, image_id, | |
detection_threshold=DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD): | |
"""Apply the detector to an image. | |
Args: | |
image: the PIL Image object | |
image_id: a path to identify the image; will be in the "file" field of the output object | |
detection_threshold: confidence above which to include the detection proposal | |
Returns: | |
A dict with the following fields, see the 'images' key in https://github.com/microsoft/CameraTraps/tree/master/api/batch_processing#batch-processing-api-output-format | |
- 'file' (always present) | |
- 'max_detection_conf' | |
- 'detections', which is a list of detection objects containing keys 'category', 'conf' and 'bbox' | |
- 'failure' | |
""" | |
result = { | |
'file': image_id | |
} | |
try: | |
b_box, b_score, b_class = self._generate_detections_one_image(image) | |
# our batch size is 1; need to loop the batch dim if supporting batch size > 1 | |
boxes, scores, classes = b_box[0], b_score[0], b_class[0] | |
detections_cur_image = [] # will be empty for an image with no confident detections | |
max_detection_conf = 0.0 | |
for b, s, c in zip(boxes, scores, classes): | |
if s > detection_threshold: | |
detection_entry = { | |
'category': str(int(c)), # use string type for the numerical class label, not int | |
'conf': truncate_float(float(s), # cast to float for json serialization | |
precision=TFDetector.CONF_DIGITS), | |
'bbox': TFDetector.__convert_coords(b) | |
} | |
detections_cur_image.append(detection_entry) | |
if s > max_detection_conf: | |
max_detection_conf = s | |
result['max_detection_conf'] = truncate_float(float(max_detection_conf), | |
precision=TFDetector.CONF_DIGITS) | |
result['detections'] = detections_cur_image | |
except Exception as e: | |
result['failure'] = TFDetector.FAILURE_TF_INFER | |
print('TFDetector: image {} failed during inference: {}'.format(image_id, str(e))) | |
return result | |
#%% Main function | |
def load_and_run_detector(model_file, image_file_names, output_dir, | |
render_confidence_threshold=TFDetector.DEFAULT_RENDERING_CONFIDENCE_THRESHOLD, | |
crop_images=False): | |
"""Load and run detector on target images, and visualize the results.""" | |
if len(image_file_names) == 0: | |
print('Warning: no files available') | |
return | |
start_time = time.time() | |
tf_detector = TFDetector(model_file) | |
elapsed = time.time() - start_time | |
print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed))) | |
detection_results = [] | |
time_load = [] | |
time_infer = [] | |
# Dictionary mapping output file names to a collision-avoidance count. | |
# | |
# Since we'll be writing a bunch of files to the same folder, we rename | |
# as necessary to avoid collisions. | |
output_filename_collision_counts = {} | |
def input_file_to_detection_file(fn, crop_index=-1): | |
"""Creates unique file names for output files. | |
This function does 3 things: | |
1) If the --crop flag is used, then each input image may produce several output | |
crops. For example, if foo.jpg has 3 detections, then this function should | |
get called 3 times, with crop_index taking on 0, 1, then 2. Each time, this | |
function appends crop_index to the filename, resulting in | |
foo_crop00_detections.jpg | |
foo_crop01_detections.jpg | |
foo_crop02_detections.jpg | |
2) If the --recursive flag is used, then the same file (base)name may appear | |
multiple times. However, we output into a single flat folder. To avoid | |
filename collisions, we prepend an integer prefix to duplicate filenames: | |
foo_crop00_detections.jpg | |
0000_foo_crop00_detections.jpg | |
0001_foo_crop00_detections.jpg | |
3) Prepends the output directory: | |
out_dir/foo_crop00_detections.jpg | |
Args: | |
fn: str, filename | |
crop_index: int, crop number | |
Returns: output file path | |
""" | |
fn = os.path.basename(fn).lower() | |
name, ext = os.path.splitext(fn) | |
if crop_index >= 0: | |
name += '_crop{:0>2d}'.format(crop_index) | |
fn = '{}{}{}'.format(name, ImagePathUtils.DETECTION_FILENAME_INSERT, '.jpg') | |
if fn in output_filename_collision_counts: | |
n_collisions = output_filename_collision_counts[fn] | |
fn = '{:0>4d}'.format(n_collisions) + '_' + fn | |
output_filename_collision_counts[fn] += 1 | |
else: | |
output_filename_collision_counts[fn] = 0 | |
fn = os.path.join(output_dir, fn) | |
return fn | |
for im_file in tqdm(image_file_names): | |
try: | |
start_time = time.time() | |
image = viz_utils.load_image(im_file) | |
elapsed = time.time() - start_time | |
time_load.append(elapsed) | |
except Exception as e: | |
print('Image {} cannot be loaded. Exception: {}'.format(im_file, e)) | |
result = { | |
'file': im_file, | |
'failure': TFDetector.FAILURE_IMAGE_OPEN | |
} | |
detection_results.append(result) | |
continue | |
try: | |
start_time = time.time() | |
result = tf_detector.generate_detections_one_image(image, im_file) | |
detection_results.append(result) | |
elapsed = time.time() - start_time | |
time_infer.append(elapsed) | |
except Exception as e: | |
print('An error occurred while running the detector on image {}. Exception: {}'.format(im_file, e)) | |
continue | |
try: | |
if crop_images: | |
images_cropped = viz_utils.crop_image(result['detections'], image) | |
for i_crop, cropped_image in enumerate(images_cropped): | |
output_full_path = input_file_to_detection_file(im_file, i_crop) | |
cropped_image.save(output_full_path) | |
else: | |
# image is modified in place | |
viz_utils.render_detection_bounding_boxes(result['detections'], image, | |
label_map=TFDetector.DEFAULT_DETECTOR_LABEL_MAP, | |
confidence_threshold=render_confidence_threshold) | |
output_full_path = input_file_to_detection_file(im_file) | |
image.save(output_full_path) | |
except Exception as e: | |
print('Visualizing results on the image {} failed. Exception: {}'.format(im_file, e)) | |
continue | |
# ...for each image | |
ave_time_load = statistics.mean(time_load) | |
ave_time_infer = statistics.mean(time_infer) | |
if len(time_load) > 1 and len(time_infer) > 1: | |
std_dev_time_load = humanfriendly.format_timespan(statistics.stdev(time_load)) | |
std_dev_time_infer = humanfriendly.format_timespan(statistics.stdev(time_infer)) | |
else: | |
std_dev_time_load = 'not available' | |
std_dev_time_infer = 'not available' | |
print('On average, for each image,') | |
print('- loading took {}, std dev is {}'.format(humanfriendly.format_timespan(ave_time_load), | |
std_dev_time_load)) | |
print('- inference took {}, std dev is {}'.format(humanfriendly.format_timespan(ave_time_infer), | |
std_dev_time_infer)) | |
#%% Command-line driver | |
def main(): | |
parser = argparse.ArgumentParser( | |
description='Module to run a TF animal detection model on images') | |
parser.add_argument( | |
'detector_file', | |
help='Path to .pb TensorFlow detector model file') | |
group = parser.add_mutually_exclusive_group(required=True) # must specify either an image file or a directory | |
group.add_argument( | |
'--image_file', | |
help='Single file to process, mutually exclusive with --image_dir') | |
group.add_argument( | |
'--image_dir', | |
help='Directory to search for images, with optional recursion by adding --recursive') | |
parser.add_argument( | |
'--recursive', | |
action='store_true', | |
help='Recurse into directories, only meaningful if using --image_dir') | |
parser.add_argument( | |
'--output_dir', | |
help='Directory for output images (defaults to same as input)') | |
parser.add_argument( | |
'--threshold', | |
type=float, | |
default=TFDetector.DEFAULT_RENDERING_CONFIDENCE_THRESHOLD, | |
help=('Confidence threshold between 0 and 1.0; only render boxes above this confidence' | |
' (but only boxes above 0.1 confidence will be considered at all)')) | |
parser.add_argument( | |
'--crop', | |
default=False, | |
action="store_true", | |
help=('If set, produces separate output images for each crop, ' | |
'rather than adding bounding boxes to the original image')) | |
if len(sys.argv[1:]) == 0: | |
parser.print_help() | |
parser.exit() | |
args = parser.parse_args() | |
assert os.path.exists(args.detector_file), 'detector_file specified does not exist' | |
assert 0.0 < args.threshold <= 1.0, 'Confidence threshold needs to be between 0 and 1' # Python chained comparison | |
if args.image_file: | |
image_file_names = [args.image_file] | |
else: | |
image_file_names = ImagePathUtils.find_images(args.image_dir, args.recursive) | |
print('Running detector on {} images...'.format(len(image_file_names))) | |
if args.output_dir: | |
os.makedirs(args.output_dir, exist_ok=True) | |
else: | |
if args.image_dir: | |
args.output_dir = args.image_dir | |
else: | |
# but for a single image, args.image_dir is also None | |
args.output_dir = os.path.dirname(args.image_file) | |
load_and_run_detector(model_file=args.detector_file, | |
image_file_names=image_file_names, | |
output_dir=args.output_dir, | |
render_confidence_threshold=args.threshold, | |
crop_images=args.crop) | |
if __name__ == '__main__': | |
main() | |
#%% Interactive driver | |
if False: | |
#%% | |
model_file = r'c:\temp\models\md_v4.1.0.pb' | |
image_file_names = ImagePathUtils.find_images(r'c:\temp\demo_images\ssverymini') | |
output_dir = r'c:\temp\demo_images\ssverymini' | |
render_confidence_threshold = 0.8 | |
crop_images = True | |
load_and_run_detector(model_file=model_file, | |
image_file_names=image_file_names, | |
output_dir=output_dir, | |
render_confidence_threshold=render_confidence_threshold, | |
crop_images=crop_images) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment