Skip to content

Instantly share code, notes, and snippets.

@YimianDai
Last active August 30, 2019 21:21
Show Gist options
  • Save YimianDai/f42ccb89cd041536dbb229c2f742513b to your computer and use it in GitHub Desktop.
Save YimianDai/f42ccb89cd041536dbb229c2f742513b to your computer and use it in GitHub Desktop.
ErrorAnalysis

第一阶段:我应该只需要计算 TP、FP 这些数值

  1. self._n_pos[l] 就是 True Object(不含 difficult 目标)的个数
  2. np.sum(np.array(self._match[l]) == 1) 就是 True Positive 的个数
  3. np.sum(np.array(self._match[l]) == 0) 就是 False Positive 的个数
  4. len(self._match[l]) 就是 Positive 的个数(注意:值为 -1 的元素是匹配到 difficult 目标的预测,既不算 TP 也不算 FP)
  5. recall 就是 #TP / #T,precision 就是 #TP / #P

因此,我感觉功能只要如下就可以了

  1. __init__ 负责初始化一些量,并 call reset 函数
  2. reset 负责将一些内在统计量初始化
  3. update 负责更新 self._n_pos、self._match 这些统计量
  4. get 负责返回每类的 TP、FP、Recall、Precision 这些
  5. _update 负责根据统计量计算出 TP、FP、Recall、Precision 这些
from __future__ import division

from collections import defaultdict
import numpy as np
import mxnet as mx
from gluoncv.utils.bbox import bbox_iou

class ErrorAnalysis(mx.metric.EvalMetric):
    """Per-class error analysis (TP / FP / FN, precision, recall) for detection.

    Predictions are matched to ground-truth boxes per image and per class:
    predictions are visited in descending score order and each one is
    matched to the ground-truth box it overlaps most, provided that IoU
    reaches ``iou_thresh``. A ground-truth box can be claimed only once;
    later predictions hitting an already-claimed box are false positives.
    Predictions matched to a *difficult* ground-truth box are ignored
    (neither TP nor FP), and difficult boxes are not counted as objects.

    Reported per class:

    * ``#True``           -- number of non-difficult ground-truth objects
    * ``#Positive``       -- number of counted predictions (``#TP + #FP``)
    * ``#True_Positive``  -- predictions matched to an unclaimed object
    * ``#False_Positive`` -- predictions with no valid match
    * ``#False_Negative`` -- ``#True - #True_Positive``
    * ``Precision = #TP / #Positive`` and ``Recall = #TP / #True``

    Parameters
    ----------
    iou_thresh : float
        IoU overlap threshold for a prediction to count as a true positive.
    score_thresh : float
        Predictions whose score is below this value are not counted.
    class_names : list of str, optional
        Names used when reporting; label ``i`` is reported as
        ``class_names[i]``. If omitted, every label seen by ``update`` is
        reported under its integer value.
    """
    def __init__(self, iou_thresh=0.5, score_thresh=0.5, class_names=None):
        super(ErrorAnalysis, self).__init__('ErrorAnalysis')
        if class_names is None:
            self.num = None
        else:
            assert isinstance(class_names, (list, tuple))
            for name in class_names:
                assert isinstance(name, str), "must provide names as str"
            # ``name`` becomes the list of per-class names, ``num`` the
            # number of classes to report.
            self.name = list(class_names)
            self.num = len(class_names)
        self.reset()
        self.iou_thresh = iou_thresh
        self.score_thresh = score_thresh
        self.class_names = class_names

    def reset(self):
        """Clear the internal statistics to initial state."""
        # ``getattr``: ``num`` may not exist yet if ``reset`` runs before
        # ``__init__`` assigns it (e.g. from the base-class constructor).
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num
        # Raw statistics accumulated by ``update``, keyed by class label.
        self._n_pos = defaultdict(int)   # number of non-difficult GT boxes
        self._score = defaultdict(list)  # score of every prediction
        self._match = defaultdict(list)  # per prediction: 1 TP, 0 FP, -1 ignored

        # Derived per-class results, recomputed by ``_update``.
        self._trues = {}
        self._positives = {}
        self._true_positives = {}
        self._false_positives = {}
        self._false_negatives = {}
        self._precisions = {}
        self._recalls = {}

    def _labels_and_names(self):
        """Return ``(label, display_name)`` pairs for the classes to report."""
        if self.num is not None:
            return list(zip(range(self.num), self.name))
        # No class names were given: report every label seen so far.
        labels = sorted(int(l) for l in set(self._n_pos) | set(self._match))
        return [(l, str(l)) for l in labels]

    def get(self):
        """Print per-class statistics and return precision/recall values.

        Returns
        -------
        names : list of str
            ``'<class>/precision'`` and ``'<class>/recall'`` for every
            reported class.
        values : list of float
            The corresponding values; ``nan`` when a denominator is zero.
        """
        self._update()  # refresh the derived statistics first
        names, values = [], []
        for l, cls_name in self._labels_and_names():
            print(cls_name, ": #True", self._trues[l],
                  ", #Positive", self._positives[l],
                  ", #True_Positive", self._true_positives[l],
                  ", #False_Positive", self._false_positives[l],
                  ", #False_Negative", self._false_negatives[l],
                  ", Precision", self._precisions[l],
                  ", Recall", self._recalls[l])
            names.extend(['%s/precision' % cls_name, '%s/recall' % cls_name])
            values.extend([self._precisions[l], self._recalls[l]])
        return names, values

    # pylint: disable=arguments-differ, too-many-nested-blocks
    def update(self, pred_bboxes, pred_labels, pred_scores,
               gt_bboxes, gt_labels, gt_difficults=None):
        """Update internal buffer with latest prediction and gt pairs.

        Parameters
        ----------
        pred_bboxes : mxnet.NDArray or numpy.ndarray
            Prediction bounding boxes with shape `B, N, 4`.
            Where B is the size of mini-batch, N is the number of bboxes.
        pred_labels : mxnet.NDArray or numpy.ndarray
            Prediction bounding boxes labels with shape `B, N`.
        pred_scores : mxnet.NDArray or numpy.ndarray
            Prediction bounding boxes scores with shape `B, N`.
        gt_bboxes : mxnet.NDArray or numpy.ndarray
            Ground-truth bounding boxes with shape `B, M, 4`.
            Where B is the size of mini-batch, M is the number of ground-truths.
        gt_labels : mxnet.NDArray or numpy.ndarray
            Ground-truth bounding boxes labels with shape `B, M`.
        gt_difficults : mxnet.NDArray or numpy.ndarray, optional, default is None
            Ground-truth bounding boxes difficulty labels with shape `B, M`.

        """
        def as_numpy(a):
            """Convert a (list of) mx.NDArray into numpy.ndarray"""
            if isinstance(a, (list, tuple)):
                out = [x.asnumpy() if isinstance(x, mx.nd.NDArray) else x for x in a]
                try:
                    out = np.concatenate(out, axis=0)
                except ValueError:
                    out = np.array(out)
                return out
            elif isinstance(a, mx.nd.NDArray):
                a = a.asnumpy()
            return a

        if gt_difficults is None:
            # One ``None`` per image: treat every object as non-difficult.
            gt_difficults = [None for _ in as_numpy(gt_labels)]

        # For list inputs, drop the difficult flags if they do not line up
        # with the labels. Skip the check when the flags are ``None``
        # placeholders (e.g. the default built above), which have no shape.
        if (isinstance(gt_labels, list) and len(gt_labels) > 0
                and len(gt_difficults) > 0
                and gt_difficults[0] is not None
                and len(gt_difficults) * gt_difficults[0].shape[0]
                != len(gt_labels) * gt_labels[0].shape[0]):
            gt_difficults = [None] * len(gt_labels) * gt_labels[0].shape[0]

        for pred_bbox, pred_label, pred_score, gt_bbox, gt_label, gt_difficult in zip(
                *[as_numpy(x) for x in [pred_bboxes, pred_labels, pred_scores,
                                        gt_bboxes, gt_labels, gt_difficults]]):
            # strip padding -1 for pred and gt
            valid_pred = np.where(pred_label.flat >= 0)[0]
            pred_bbox = pred_bbox[valid_pred, :]
            pred_label = pred_label.flat[valid_pred].astype(int)
            pred_score = pred_score.flat[valid_pred]
            valid_gt = np.where(gt_label.flat >= 0)[0]
            gt_bbox = gt_bbox[valid_gt, :]
            gt_label = gt_label.flat[valid_gt].astype(int)
            if gt_difficult is None:
                gt_difficult = np.zeros(gt_bbox.shape[0])
            else:
                gt_difficult = gt_difficult.flat[valid_gt]

            for l in np.unique(np.concatenate((pred_label, gt_label)).astype(int)):
                pred_mask_l = pred_label == l
                pred_bbox_l = pred_bbox[pred_mask_l]
                pred_score_l = pred_score[pred_mask_l]
                # sort by score (descending) so matching below is greedy
                order = pred_score_l.argsort()[::-1]
                pred_bbox_l = pred_bbox_l[order]
                pred_score_l = pred_score_l[order]

                gt_mask_l = gt_label == l
                gt_bbox_l = gt_bbox[gt_mask_l]
                gt_difficult_l = gt_difficult[gt_mask_l]

                self._n_pos[l] += np.logical_not(gt_difficult_l).sum()
                # ``_score`` and ``_match`` grow in lockstep, one entry per
                # prediction, in the same (sorted) order.
                self._score[l].extend(pred_score_l)

                if len(pred_bbox_l) == 0:
                    continue
                if len(gt_bbox_l) == 0:
                    self._match[l].extend((0,) * pred_bbox_l.shape[0])
                    continue

                # VOC evaluation follows integer typed bounding boxes.
                pred_bbox_l = pred_bbox_l.copy()
                pred_bbox_l[:, 2:] += 1
                gt_bbox_l = gt_bbox_l.copy()
                gt_bbox_l[:, 2:] += 1

                iou = bbox_iou(pred_bbox_l, gt_bbox_l)
                gt_index = iou.argmax(axis=1)
                # set -1 if there is no matching ground truth
                gt_index[iou.max(axis=1) < self.iou_thresh] = -1
                del iou

                # selec[j] is True once ground-truth box j has been claimed.
                selec = np.zeros(gt_bbox_l.shape[0], dtype=bool)
                for gt_idx in gt_index:
                    if gt_idx >= 0:
                        if gt_difficult_l[gt_idx]:
                            self._match[l].append(-1)  # ignored
                        elif not selec[gt_idx]:
                            self._match[l].append(1)   # true positive
                        else:
                            self._match[l].append(0)   # duplicate detection
                        selec[gt_idx] = True
                    else:
                        self._match[l].append(0)       # no overlapping GT

    def _update(self):
        """Recompute the per-class counts and ratios from the raw statistics."""
        for l, _ in self._labels_and_names():
            match = np.asarray(self._match[l], dtype=int)
            score = np.asarray(self._score[l], dtype=float)
            # Matching is greedy in descending score order, so dropping
            # low-score predictions here gives the same result as dropping
            # them before matching.
            match = match[score >= self.score_thresh]

            tp = int(np.sum(match == 1))
            fp = int(np.sum(match == 0))
            n_true = int(self._n_pos[l])
            # Entries equal to -1 (matched a difficult box) are neither TP
            # nor FP, so they are excluded from the positives as well.
            n_positive = tp + fp

            self._trues[l] = n_true
            self._positives[l] = n_positive
            self._true_positives[l] = tp
            self._false_positives[l] = fp
            self._false_negatives[l] = n_true - tp
            self._precisions[l] = tp / n_positive if n_positive else float('nan')
            self._recalls[l] = tp / n_true if n_true else float('nan')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment