This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Suppression pass over the detections of a single class.
# NOTE(review): presumably image_pred_class is already sorted by descending
# confidence and idx is its row count — confirm against the surrounding code.
for i in range(idx):
    # IoU of box i against every box that comes after it in the tensor.
    try:
        ious = bounding_box_iou(
            image_pred_class[i].unsqueeze(0), image_pred_class[i + 1:]
        )
    except (ValueError, IndexError):
        # The IoU computation failed for this row (presumably no boxes are
        # left to compare against) — stop the pass.
        break
    # Column mask: 1.0 where IoU is below the threshold, 0.0 otherwise, so
    # that multiplying by it zeroes out the overlapping detections.
    iou_mask = (ious < nms_conf).float().unsqueeze(1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Order the detections of this class by the value in column 4, highest first.
# NOTE(review): column 4 is presumably the objectness confidence — confirm.
# torch.sort returns (sorted_values, indices); only the indices are needed.
_, conf_sort_index = torch.sort(image_pred_class[:, 4], descending=True)
image_pred_class = image_pred_class[conf_sort_index]
# Number of detections for this class.
idx = image_pred_class.size(0)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Perform NMS one class at a time.
# Zero out every row whose last column (the class index) is not `cls`.
cls_mask = image_pred_ * (image_pred_[:, -1] == cls).float().unsqueeze(1)
# Indices of the rows that survived the mask (non-zero in the second-to-last
# column). squeeze(1) drops only the trailing dimension of nonzero()'s (N, 1)
# result, so a single match stays a 1-element index and the selection below
# stays 2-D instead of collapsing to a single 1-D row.
class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze(1)
image_pred_class = image_pred_[class_mask_ind]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
try: | |
#Get the various classes detected in the image | |
img_classes = torch.unique(image_pred_[:,-1]) # -1 index holds the class index | |
except: | |
continue |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# non_zero_index holds the indices of the rows whose column 4 is non-zero.
non_zero_index = torch.nonzero(image_pred[:, 4])
# view(-1, 7) keeps the selection 2-D (rows of 7 values) even when the
# squeezed index refers to a single row.
image_pred_ = image_pred[non_zero_index.squeeze(), :].view(-1, 7)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
batch_size = prediction.size(0)
# NOTE(review): `write` is not used inside this excerpt; presumably it flags
# whether an output tensor has been started — confirm against the full code.
write = False
# Non-max suppression can only be done per image, so loop over the batch.
for ind in range(batch_size):
    image_pred = prediction[ind]
    # For every row keep only the maximum class score and the index of the
    # class that achieved it (torch.max along dim 1 returns both).
    max_conf, max_conf_score = torch.max(image_pred[:, 5:5 + num_classes], 1)
    max_conf = max_conf.float().unsqueeze(1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Convert the box attributes from (center x, center y, width, height)
# to (top-left corner x, top-left corner y,
#     bottom-right corner x, bottom-right corner y).
# The corners are computed into a scratch tensor first because every result
# depends on the original centre/size values, which must not be overwritten
# midway through the conversion.
box_corner = prediction.new_empty(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Keep only the rows whose value in column 4 exceeds the threshold and set
# everything else to zero. unsqueeze(2) adds a trailing dimension so the mask
# broadcasts across all attributes of each row.
mask = (prediction[:, :, 4] > confidence_threshold).float().unsqueeze(2)
prediction = prediction * mask
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load a sample image, turn it into a (1, 3, 416, 416) float tensor scaled to
# [0, 1], and run it through the network.
image_path = "dog-cycle-car.png"
img = cv2.imread(image_path)
if img is None:
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than inside cv2.resize.
    raise FileNotFoundError(f"Could not read image: {image_path}")
img = cv2.resize(img, (416, 416))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
img = img.transpose((2, 0, 1))  # HWC -> CHW
img = img[np.newaxis, :, :, :] / 255.0  # add batch dimension, scale to [0, 1]
img = torch.from_numpy(img).float()
model = Darknet("yolov3.cfg")
# NOTE(review): the CUDA flag is passed to the model, but neither `model` nor
# `img` is moved to the GPU here — confirm that forward() handles this.
pred = model(img, torch.cuda.is_available())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Darknet(nn.Module): | |
def __init__(self, cfgfile): | |
super(Darknet, self).__init__() | |
self.blocks = parse_cfg(cfgfile) | |
self.net_info, self.module_list = create_model(self.blocks) | |
def forward(self, x, CUDA=False): | |
modules = self.blocks[1:] | |
#We cache the outputs for the route layer | |
outputs = {} |