Last active April 28, 2020 10:38
Save pythonlessons/c64bf3dfe31e1f3fdbf111aa6fb59118 to your computer and use it in GitHub Desktop.
Yolo_v3_nms
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def nms(bboxes, iou_threshold, sigma=0.3, method='nms'):
    """Per-class hard or soft non-maximum suppression.

    Rows of ``bboxes`` are grouped by the class id in column 5. Within each
    class the row with the highest score (column 4) is repeatedly selected,
    and the scores of the remaining rows of that class are re-weighted
    according to their IoU with it, as computed by ``bboxes_iou`` on
    columns 0-3. Rows whose score is no longer positive are dropped.

    Args:
        bboxes: 2-D array with one box per row. Column 4 is the score,
            column 5 the class id, and columns 0-3 are the box coordinates
            handed to ``bboxes_iou``. NOTE(review): the coordinate
            convention is whatever ``bboxes_iou`` expects -- confirm there.
        iou_threshold: for ``method='nms'``, rows whose IoU with the
            selected row is strictly greater than this value are discarded.
        sigma: for ``method='soft-nms'``, the Gaussian decay parameter; the
            remaining scores are multiplied by ``exp(-iou**2 / sigma)``.
        method: ``'nms'`` (hard suppression) or ``'soft-nms'`` (Gaussian
            score decay).

    Returns:
        List of the selected rows (1-D arrays), grouped by class and, within
        a class, in selection order (highest remaining score first). With
        soft-NMS a returned row carries the decayed score it had when it
        was selected.

    Raises:
        ValueError: if ``method`` is neither ``'nms'`` nor ``'soft-nms'``.
    """
    # Validate once, up front. An ``assert`` is stripped under ``python -O``,
    # which would silently disable suppression for a mistyped method.
    if method not in ('nms', 'soft-nms'):
        raise ValueError(
            f"method must be 'nms' or 'soft-nms', got {method!r}")

    best_bboxes = []
    for cls in set(bboxes[:, 5]):
        # Boolean-mask indexing returns a copy, so the score updates below
        # never touch the caller's array.
        cls_bboxes = bboxes[bboxes[:, 5] == cls]
        while len(cls_bboxes) > 0:
            # Select the highest-scoring remaining box of this class.
            max_ind = np.argmax(cls_bboxes[:, 4])
            best_bbox = cls_bboxes[max_ind]
            best_bboxes.append(best_bbox)
            # Rebuild without the selected row. concatenate allocates a new
            # array, so rescaling it below cannot alter ``best_bbox``.
            cls_bboxes = np.concatenate(
                [cls_bboxes[:max_ind], cls_bboxes[max_ind + 1:]])

            # Re-weight the remaining boxes by their overlap with the
            # selected one.
            iou = bboxes_iou(best_bbox[np.newaxis, :4], cls_bboxes[:, :4])
            if method == 'nms':
                # Hard NMS: zero the score of any box overlapping too much.
                weight = np.ones((len(iou),), dtype=np.float32)
                weight[iou > iou_threshold] = 0.0
            else:
                # Soft-NMS: Gaussian decay of the score by IoU.
                weight = np.exp(-(iou ** 2 / sigma))
            cls_bboxes[:, 4] = cls_bboxes[:, 4] * weight

            # Drop every box whose score is no longer positive.
            cls_bboxes = cls_bboxes[cls_bboxes[:, 4] > 0.]
    return best_bboxes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.