Skip to content

Instantly share code, notes, and snippets.

@remi-or
Last active June 10, 2022 19:25
Show Gist options
  • Save remi-or/8b2a6ef3bc59821305498ad627e16502 to your computer and use it in GitHub Desktop.
Save remi-or/8b2a6ef3bc59821305498ad627e16502 to your computer and use it in GitHub Desktop.
def aug_test(self,
             imgs: List[Tensor],
             img_metas: List[dict],
             rescale: bool = False,
             *,
             score_thr: float = 0.5,
             iou_thr: float = 0.7) -> List[List[np.ndarray]]:
    """Test-time augmentation: pool per-view detections and merge them with rotated NMS.

    Each augmented view is run through ``self.simple_test``; the per-class
    detections of all views are pooled and de-duplicated with a single
    ``multiclass_nms_rotated`` pass.

    NOTE(review): pooling assumes every view's detections are expressed in
    the same (original-image) coordinate frame. Confirm that ``simple_test``
    undoes any geometric augmentation (e.g. flips) and that ``rescale=True``
    is used when views differ in scale.

    Args:
        imgs: One input tensor per augmented view.
        img_metas: Meta information matching ``imgs`` one-to-one.
        rescale: Forwarded unchanged to ``self.simple_test`` for every view.
        score_thr: Score threshold passed to ``multiclass_nms_rotated``
            (previously hard-coded to 0.5, still the default).
        iou_thr: IoU threshold of the rotated NMS (previously hard-coded to
            0.7, still the default).

    Returns:
        ``[per_class_dets]`` where ``per_class_dets[c]`` holds the merged
        detections labelled ``c``. Each row is the 5 box values followed by
        the score, the same row layout this function reads from
        ``simple_test``.
    """
    num_classes = self.roi_head.bbox_head.num_classes
    # multiclass_nms_rotated treats the LAST score column as background and
    # ignores it, so reserve one extra column. Without it, every detection of
    # the last real class would be discarded.
    num_score_cols = num_classes + 1

    # Seed with empty float32 arrays so concatenation works even when no view
    # produced a detection, without upcasting the result to float64.
    box_chunks = [np.zeros((0, 5), dtype=np.float32)]
    score_chunks = [np.zeros((0, num_score_cols), dtype=np.float32)]
    for img, img_meta in zip(imgs, img_metas):
        # [0]: take the result of the single image in this view's batch.
        per_class_dets = self.simple_test(img, img_meta, None, rescale)[0]
        for label, dets in enumerate(per_class_dets):
            box_chunks.append(dets[:, :-1])
            # One-hot score matrix: the detection's score in its own class
            # column, zero elsewhere (including the background column).
            one_hot = np.zeros((dets.shape[0], num_score_cols), dtype=np.float32)
            one_hot[:, label] = dets[:, -1]
            score_chunks.append(one_hot)

    # Concatenate once; re-stacking inside the loop was quadratic.
    all_boxes = torch.from_numpy(np.concatenate(box_chunks).astype(np.float32))
    all_scores = torch.from_numpy(np.concatenate(score_chunks))

    merged, labels = multiclass_nms_rotated(multi_bboxes=all_boxes,
                                            multi_scores=all_scores,
                                            score_thr=score_thr,
                                            nms=Namespace(iou_thr=iou_thr))

    # Regroup the merged detections by class label with boolean masks. Empty
    # classes yield a (0, ncols) array.
    merged_np = merged.cpu().numpy()
    labels_np = labels.cpu().numpy()
    return [[merged_np[labels_np == c] for c in range(num_classes)]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment