import torch


def bounding_box_iou(box1, box2):
    """
    Returns the IoU of two sets of bounding boxes given in
    (top-left x, top-left y, bottom-right x, bottom-right y) format.
    """
    # Get the corner coordinates of the bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

    # Get the coordinates of the intersection rectangle
    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)

    # Intersection area (clamped to zero when the boxes do not overlap)
    intersection_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * \
                        torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0)

    # Union area
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)

    iou = intersection_area / (b1_area + b2_area - intersection_area)
    return iou
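
# Minimal usage sketch (not from the original gist): the toy corner-format boxes
# below are made-up values, shown only to illustrate what bounding_box_iou expects.
# box_a = torch.tensor([[0., 0., 10., 10.]])
# candidates = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
# bounding_box_iou(box_a, candidates)  # first IoU is 1.0, second roughly 0.17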


def final_detection(prediction, confidence_threshold, num_classes, nms_conf=0.4):
    # Keep only boxes whose object confidence is above the threshold;
    # every other row is zeroed out.
    mask = (prediction[:, :, 4] > confidence_threshold).float().unsqueeze(2)
    prediction = prediction * mask

    # Convert the (center x, center y, width, height) attributes of our boxes
    # to (top-left x, top-left y, bottom-right x, bottom-right y)
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = (prediction[:, :, 0] - prediction[:, :, 2] / 2)
    box_corner[:, :, 1] = (prediction[:, :, 1] - prediction[:, :, 3] / 2)
    box_corner[:, :, 2] = (prediction[:, :, 0] + prediction[:, :, 2] / 2)
    box_corner[:, :, 3] = (prediction[:, :, 1] + prediction[:, :, 3] / 2)
    prediction[:, :, :4] = box_corner[:, :, :4]

    batch_size = prediction.size(0)
    write = False

    # Non-max suppression works on individual images, so loop over the batch
    for ind in range(batch_size):
        image_pred = prediction[ind]

        # For each row keep only the maximum class probability
        # and its corresponding class index
        max_conf, max_conf_score = torch.max(image_pred[:, 5:5 + num_classes], 1)
        max_conf = max_conf.float().unsqueeze(1)
        max_conf_score = max_conf_score.float().unsqueeze(1)
        combined = (image_pred[:, :5], max_conf, max_conf_score)
        # Concatenate the class index and max probability with the box coordinates as columns
        image_pred = torch.cat(combined, 1)

        # Remember we set the bounding box rows with an object confidence
        # below the threshold to zero? Let's get rid of them.
        non_zero_index = torch.nonzero(image_pred[:, 4])  # indexes of the remaining rows
        image_pred_ = image_pred[non_zero_index.squeeze(), :].view(-1, 7)

        try:
            # Get the various classes detected in the image
            img_classes = torch.unique(image_pred_[:, -1])  # the -1 index holds the class index
        except Exception:
            continue
        for cls in img_classes:
            # Perform NMS for one particular class at a time:
            # zero out the rows that belong to other classes ...
            cls_mask = image_pred_ * (image_pred_[:, -1] == cls).float().unsqueeze(1)
            # ... and keep only the non-zero rows
            class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze()
            image_pred_class = image_pred_[class_mask_ind].view(-1, 7)

            # Sort the detections by object confidence, highest first
            conf_sort_index = torch.sort(image_pred_class[:, 4], descending=True)[1]
            image_pred_class = image_pred_class[conf_sort_index]
            idx = image_pred_class.size(0)

            for i in range(idx):
                # Get the IoUs of all boxes that come after the one we are
                # looking at in the loop
                try:
                    ious = bounding_box_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i + 1:])
                except ValueError:
                    break
                except IndexError:
                    break

                # Zero out all the detections that have IoU > threshold
                iou_mask = (ious < nms_conf).float().unsqueeze(1)
                image_pred_class[i + 1:] *= iou_mask

                # Remove the zeroed-out entries
                non_zero_ind = torch.nonzero(image_pred_class[:, 4]).squeeze()
                image_pred_class = image_pred_class[non_zero_ind].view(-1, 7)

            # Concatenate the batch index of the image to each detection so we can
            # tell which image a detection belongs to. All detections of the batch
            # are held in one flat structure: the batch dimension is flattened and
            # the image is identified by the extra batch-index column.
            batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind)
            seq = batch_ind, image_pred_class
            if not write:
                output = torch.cat(seq, 1)
                write = True
            else:
                out = torch.cat(seq, 1)
                output = torch.cat((output, out))

    # If nothing survived the thresholds there is no output tensor to return
    return output if write else 0
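

if __name__ == "__main__":
    # Quick smoke test, not part of the original gist. The shape and attribute
    # layout below, (batch, num_boxes, 5 + num_classes) with rows of
    # (center x, center y, width, height, objectness, class scores), are the usual
    # YOLOv3 convention this gist assumes; the random tensor is purely illustrative.
    num_classes = 80
    dummy_prediction = torch.rand(1, 10647, 5 + num_classes)
    detections = final_detection(dummy_prediction, confidence_threshold=0.5,
                                 num_classes=num_classes, nms_conf=0.4)
    # Each output row is (batch index, x1, y1, x2, y2, objectness,
    # max class probability, class index)
    print(detections if isinstance(detections, int) else detections.shape)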