import torch


def bounding_box_iou(box1, box2):
    """
    Returns the IoU of two sets of bounding boxes given in (x1, y1, x2, y2) corner format.
    """
    # Get the corner coordinates of the bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3]
    # Get the coordinates of the intersection rectangle
    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)
    # Intersection area (clamped to zero when the boxes do not overlap)
    intersection_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0)
    # Union area
    b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1)
    iou = intersection_area / (b1_area + b2_area - intersection_area)
    return iou
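
# A minimal usage sketch (not part of the original gist); the box values below
# are hypothetical and chosen only to illustrate the expected (x1, y1, x2, y2)
# corner layout, one box per row.
def _iou_example():
    box_a = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
    box_b = torch.tensor([[5.0, 5.0, 15.0, 15.0],
                          [20.0, 20.0, 30.0, 30.0]])
    # box_a broadcasts against each row of box_b; the second box does not
    # overlap box_a at all, so its IoU is 0.
    return bounding_box_iou(box_a, box_b)
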

def final_detection(prediction, confidence_threshold, num_classes, nms_conf=0.4):
    # Keep only the predictions whose object confidence is above the threshold;
    # every other row is zeroed out.
    mask = (prediction[:,:,4] > confidence_threshold).float().unsqueeze(2)
    prediction = prediction*mask
    # Convert the (center x, center y, width, height) attributes of our boxes
    # to (top-left corner x, top-left corner y, bottom-right corner x, bottom-right corner y)
    box_corner = prediction.new(prediction.shape)
    box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
    box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
    box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)
    box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
    prediction[:,:,:4] = box_corner[:,:,:4]

    batch_size = prediction.size(0)
    write = False
    # Non-max suppression is done on one image at a time, so we loop through the images
    for ind in range(batch_size):
        image_pred = prediction[ind]
        # For each box, keep only the maximum class probability
        # and its corresponding class index
        max_conf, max_conf_score = torch.max(image_pred[:,5:5+ num_classes], 1)
        max_conf = max_conf.float().unsqueeze(1)
        max_conf_score = max_conf_score.float().unsqueeze(1)
        combined = (image_pred[:,:5], max_conf, max_conf_score)
        # Concatenate the class index and max probability with the box coordinates as columns
        image_pred = torch.cat(combined, 1)
        # Remember we had set the bounding box rows with an object confidence
        # below the threshold to zero? Let's get rid of them.
        non_zero_index = torch.nonzero(image_pred[:,4])  # indexes of the remaining rows
        image_pred_ = image_pred[non_zero_index.squeeze(),:].view(-1,7)
        try:
            # Get the various classes detected in the image
            img_classes = torch.unique(image_pred_[:,-1])  # the -1 index holds the class index
        except:
            continue
        for cls in img_classes:
            # Perform NMS for one class at a time:
            # get the detections belonging to this particular class
            cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)
            # take the indexes of the non-zero rows
            class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
            image_pred_class = image_pred_[class_mask_ind].view(-1,7)
            # Sort the detections by object confidence, highest first
            # (index 1 of torch.sort gives the sort order)
            conf_sort_index = torch.sort(image_pred_class[:,4], descending=True)[1]
            image_pred_class = image_pred_class[conf_sort_index]
            idx = image_pred_class.size(0)
            for i in range(idx):
                # Get the IoUs of all boxes that come after the one we are
                # looking at in the loop
                try:
                    ious = bounding_box_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:])
                except ValueError:
                    break
                except IndexError:
                    break
                # Zero out all the detections that have IoU > threshold
                iou_mask = (ious < nms_conf).float().unsqueeze(1)
                image_pred_class[i+1:] *= iou_mask
                # Remove the zeroed-out entries
                non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze()
                image_pred_class = image_pred_class[non_zero_ind].view(-1,7)
            # Concatenate the batch index of the image to each detection:
            # this tells us which image the detection corresponds to.
            # We use a flat structure to hold ALL the detections from the batch,
            # so the batch dimension is flattened and the extra column
            # identifies the image.
            batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind)
            seq = batch_ind, image_pred_class
            if not write:
                output = torch.cat(seq,1)
                write = True
            else:
                out = torch.cat(seq,1)
                output = torch.cat((output,out))
    # If nothing survived the confidence threshold in any image, no detections
    # were ever written, so there is no output tensor to return.
    return output if write else 0
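

# A minimal end-to-end sketch (not part of the original gist). The shapes and
# values are assumptions: `raw_prediction` stands in for a YOLO-style output of
# shape (batch_size, num_boxes, 5 + num_classes), with each row holding
# (center x, center y, width, height, objectness, class scores...).
def _detection_example():
    num_classes = 4
    raw_prediction = torch.rand(1, 50, 5 + num_classes)
    detections = final_detection(raw_prediction,
                                 confidence_threshold=0.5,
                                 num_classes=num_classes,
                                 nms_conf=0.4)
    # Each surviving row is [image index, x1, y1, x2, y2, objectness,
    # class confidence, class index]; 0 is returned if nothing survives.
    return detections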