Created
April 12, 2021 05:19
-
-
Save nvnnghia/88d465db3a2885396e36eda32a1fa790 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tqdm import tqdm | |
import numpy as np | |
import os | |
import warnings | |
def bb_intersection_over_union(A, B) -> float:
    """
    Compute the intersection-over-union (IoU) of two axis-aligned boxes.

    :param A: first box, indexable as (x1, y1, x2, y2)
    :param B: second box, indexable as (x1, y1, x2, y2)
    :return: intersection area divided by union area; 0.0 when the
        intersection area is zero
    """
    ax1, ay1, ax2, ay2 = A[0], A[1], A[2], A[3]
    bx1, by1, bx2, by2 = B[0], B[1], B[2], B[3]

    # Extent of the overlap rectangle along each axis, floored at zero so
    # that disjoint boxes produce an empty intersection.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h
    if intersection == 0:
        return 0.0

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    # Union = sum of both areas minus the part counted twice.
    union = float(area_a + area_b - intersection)
    return intersection / union
def prefilter_boxes(boxes, scores, labels, weights, thr):
    """
    Clean the raw predictions of every model and group them by label.

    For each model the boxes whose score is below ``thr`` are dropped.
    Surviving boxes have swapped corners put back in order, coordinates
    clamped to [0, 1] (a warning is emitted for every correction), and
    zero-area boxes are skipped with a warning. If a model's boxes and
    scores (or labels) differ in length, an error is printed and
    ``exit()`` is called.

    :param boxes: per-model sequences of boxes, each indexable as (x1, y1, x2, y2)
    :param scores: per-model sequences of scores, parallel to ``boxes``
    :param labels: per-model sequences of labels, parallel to ``boxes``
    :param weights: per-model factor each score is multiplied by
    :param thr: boxes with a score strictly below this value are skipped
    :return: dict mapping int label to a numpy array whose rows are
        [label, weighted score, x1, y1, x2, y2], ordered by weighted score
        from highest to lowest
    """
    # (below-zero message, above-one message) for x1, x2, y1, y2, in the
    # order in which the coordinates are checked.
    clamp_messages = (
        ('X1 < 0 in box. Set it to 0.',
         'X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.'),
        ('X2 < 0 in box. Set it to 0.',
         'X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.'),
        ('Y1 < 0 in box. Set it to 0.',
         'Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.'),
        ('Y2 < 0 in box. Set it to 0.',
         'Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.'),
    )

    grouped = dict()
    for t, model_boxes in enumerate(boxes):
        model_scores = scores[t]
        n_boxes = len(model_boxes)
        if n_boxes != len(model_scores):
            print('Error. Length of boxes arrays not equal to length of scores array: {} != {}'.format(n_boxes, len(model_scores)))
            exit()
        model_labels = labels[t]
        if n_boxes != len(model_labels):
            print('Error. Length of boxes arrays not equal to length of labels array: {} != {}'.format(n_boxes, len(model_labels)))
            exit()

        for j, score in enumerate(model_scores):
            if score < thr:
                continue
            label = int(model_labels[j])
            box_part = model_boxes[j]
            x1, y1, x2, y2 = [float(box_part[i]) for i in range(4)]

            # Put the corners in (min, max) order before range checks.
            if x2 < x1:
                warnings.warn('X2 < X1 value in box. Swap them.')
                x1, x2 = x2, x1
            if y2 < y1:
                warnings.warn('Y2 < Y1 value in box. Swap them.')
                y1, y2 = y2, y1

            # Clamp every coordinate into [0, 1].
            clamped = []
            for value, (below_msg, above_msg) in zip((x1, x2, y1, y2), clamp_messages):
                if value < 0:
                    warnings.warn(below_msg)
                    value = 0
                if value > 1:
                    warnings.warn(above_msg)
                    value = 1
                clamped.append(value)
            x1, x2, y1, y2 = clamped

            if (x2 - x1) * (y2 - y1) == 0.0:
                warnings.warn("Zero area box skipped: {}.".format(box_part))
                continue

            row = [int(label), float(score) * weights[t], x1, y1, x2, y2]
            grouped.setdefault(label, []).append(row)

    # Turn each group into an array ordered by weighted score, best first.
    ordered = dict()
    for label, rows in grouped.items():
        table = np.array(rows)
        ranking = table[:, 1].argsort()[::-1]
        ordered[label] = table[ranking]
    return ordered
def get_weighted_box(boxes, conf_type='avg'):
    """
    Fuse a set of boxes into one confidence-weighted box.

    Every row of ``boxes`` is [label, score, x1, y1, x2, y2]. The fused
    coordinates are the score-weighted mean of the rows' coordinates. When
    the scores sum to zero (e.g. every box has score 0) a weighted mean is
    undefined, so the plain mean of the coordinates is used instead of
    producing NaN values.

    :param boxes: non-empty set of boxes to fuse (2-D array or nested
        sequence with six values per row)
    :param conf_type: how the fused score is computed: 'avg' (mean of the
        scores), 'max' (largest score) or 'sum' (sum of the scores). Any
        other value leaves the fused score at 0.
    :return: float32 array [label, score, x1, y1, x2, y2]; the label is
        taken from the first row
    """
    rows = np.asarray(boxes, dtype=np.float64)
    confs = rows[:, 1]
    coords = rows[:, 2:]
    total = confs.sum()

    fused = np.zeros(6, dtype=np.float32)
    fused[0] = rows[0, 0]

    if conf_type == 'avg':
        fused[1] = total / len(rows)
    elif conf_type == 'max':
        fused[1] = confs.max()
    elif conf_type == 'sum':
        fused[1] = total

    if total != 0:
        # Score-weighted mean of each coordinate column.
        fused[2:] = (confs @ coords) / total
    else:
        # All weights are zero: dividing would give NaN, so weight the
        # boxes equally.
        fused[2:] = coords.mean(axis=0)
    return fused
def find_matching_box(boxes_list, new_box, match_iou):
    """
    Find the box in ``boxes_list`` that overlaps ``new_box`` the most.

    Only boxes whose label (index 0) equals the label of ``new_box`` are
    considered, and a candidate wins only if its IoU with ``new_box`` is
    strictly greater than the best value seen so far, which starts at
    ``match_iou``.

    :param boxes_list: sequence of boxes laid out as [label, score, x1, y1, x2, y2]
    :param new_box: box to match, same layout
    :param match_iou: IoU a candidate has to exceed to count as a match
    :return: (index of the best match or -1 if none, best IoU or ``match_iou`` if none)
    """
    winner = -1
    top_iou = match_iou
    for position, candidate in enumerate(boxes_list):
        if candidate[0] == new_box[0]:
            overlap = bb_intersection_over_union(candidate[2:], new_box[2:])
            if overlap > top_iou:
                winner, top_iou = position, overlap
    return winner, top_iou
def weighted_boxes_fusion_customized(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
    """
    Fuse overlapping same-label boxes predicted by several models.

    Predictions are cleaned and grouped by label with ``prefilter_boxes``.
    Within each label the highest-scoring remaining box becomes an anchor;
    every other remaining box whose IoU with that anchor exceeds ``iou_thr``
    joins the anchor's cluster. Each cluster is then merged into one box by
    ``get_weighted_box`` with ``conf_type='sum'``.

    :param boxes_list: list of boxes predictions from each model, each box is 4 numbers.
        It has 3 dimensions (models_number, model_preds, 4).
        Order of boxes: x1, y1, x2, y2. We expect float normalized coordinates [0; 1]
    :param scores_list: list of scores for each model
    :param labels_list: list of labels for each model
    :param weights: list of weights for each model. Default: None, which means weight == 1
        for each model. A list of the wrong length is replaced by all ones after a printed warning.
    :param iou_thr: IoU value a box has to exceed against the cluster anchor to be a match
    :param skip_box_thr: exclude boxes with score lower than this variable
    :param conf_type: must be 'avg' or 'max', otherwise a message is printed and ``exit()``
        is called. The value is only validated here: clusters are always merged with 'sum'.
    :param allows_overflow: accepted but not read by this function
    :return: boxes: boxes coordinates (Order of boxes: x1, y1, x2, y2).
    :return: scores: confidence scores
    :return: labels: boxes labels
    """
    n_models = len(boxes_list)
    if weights is None:
        weights = np.ones(n_models)
    if len(weights) != n_models:
        print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), n_models))
        weights = np.ones(n_models)
    weights = np.array(weights)

    if conf_type not in ['avg', 'max']:
        print('Unknown conf_type: {}. Must be "avg" or "max"'.format(conf_type))
        exit()

    per_label = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
    if not per_label:
        return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))

    fused_per_label = []
    for label_rows in per_label.values():
        # Rows arrive sorted by score, so the first pending row is always
        # the best-scoring box that has not been clustered yet.
        pending = label_rows.tolist()
        clusters = []
        while pending:
            anchor = pending[0]
            members = [anchor]
            leftover = []
            for candidate in pending[1:]:
                if anchor[0] == candidate[0] and bb_intersection_over_union(anchor[2:6], candidate[2:6]) > iou_thr:
                    members.append(candidate)
                else:
                    leftover.append(candidate)
            clusters.append(members)
            pending = leftover

        fused = [get_weighted_box(np.array(members), conf_type='sum') for members in clusters]
        if len(fused) > 0:
            fused_per_label.append(np.array(fused))

    merged = np.concatenate(fused_per_label, axis=0)
    # Highest fused score first.
    ranked = merged[merged[:, 1].argsort()[::-1]]
    return ranked[:, 2:], ranked[:, 1], ranked[:, 0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment