Created March 25, 2018 01:22
Tony607/49916eed48668000f9ee17903f6a3fdd
Gentle guide on how YOLO Object Localization works with Keras (Part 2)
import tensorflow as tf


def yolo_non_max_suppression(scores, boxes, classes, max_boxes=10, iou_threshold=0.5):
    """
    Applies non-max suppression (NMS) to a set of boxes.

    Arguments:
    scores -- tensor of shape (None,), output of yolo_filter_boxes()
    boxes -- tensor of shape (None, 4), output of yolo_filter_boxes(), with corner
             coordinates scaled to the image size
    classes -- tensor of shape (None,), output of yolo_filter_boxes()
    max_boxes -- integer, maximum number of predicted boxes to keep
    iou_threshold -- float, "intersection over union" threshold used for NMS filtering;
                     a box is discarded when its IoU with a higher-scoring kept box
                     exceeds this value

    Returns:
    scores -- tensor of shape (None,), predicted score for each kept box
    boxes -- tensor of shape (None, 4), predicted box coordinates
    classes -- tensor of shape (None,), predicted class for each kept box

    Note: the "None" dimension of the output tensors is at most max_boxes.
    NMS runs over all boxes at once, ignoring classes, so a high-scoring box of one
    class can suppress an overlapping box of another class.
    """
    # Indices of the boxes to keep. tf.image.non_max_suppression accepts a plain
    # Python int for max_output_size, so no Keras variable or session is needed.
    nms_indices = tf.image.non_max_suppression(boxes, scores, max_boxes, iou_threshold)

    # Select only the kept entries from scores, boxes and classes
    scores = tf.gather(scores, nms_indices)
    boxes = tf.gather(boxes, nms_indices)
    classes = tf.gather(classes, nms_indices)

    return scores, boxes, classes
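To make the behaviour of `tf.image.non_max_suppression` concrete, here is a minimal pure-Python sketch of greedy NMS, assuming boxes are given as `(y1, x1, y2, x2)` corner tuples. The helpers `iou` and `greedy_nms` are hypothetical names for illustration, not part of TensorFlow or Keras, and the TensorFlow op may differ in details such as tie-breaking between equal scores.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # assumed (y1, x1, y2, x2) corner format


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two corner-format boxes (hypothetical helper)."""
    # Corners of the intersection rectangle
    y1 = max(a[0], b[0])
    x1 = max(a[1], b[1])
    y2 = min(a[2], b[2])
    x2 = min(a[3], b[3])
    # Clamp at zero so disjoint boxes have no intersection area
    inter = max(0.0, y2 - y1) * max(0.0, x2 - x1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], scores: Sequence[float],
               max_boxes: int = 10, iou_threshold: float = 0.5) -> List[int]:
    """Return indices of kept boxes, highest score first (hypothetical helper)."""
    # Visit candidates from highest to lowest score
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        if len(keep) >= max_boxes:
            break
        # Keep box i only if it does not overlap any already-kept box too much
        if all(iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep


if __name__ == "__main__":
    boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
    scores = [0.9, 0.8, 0.7]
    # Boxes 0 and 1 overlap with IoU = 81/119 (about 0.68), so box 1 is suppressed
    print(greedy_nms(boxes, scores))  # [0, 2]
```

The greedy order matters: the highest-scoring box is always kept, and each later box survives only if it does not overlap any box already chosen by more than `iou_threshold`. The returned indices play the same role as `nms_indices` in `yolo_non_max_suppression`, where they are used to gather the surviving scores, boxes and classes.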