Created
March 4, 2019 14:40
-
-
Save botcs/5d13a744104ab1fa9fdd9987ea7ff97a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class COCOWrapper(object):
    """
    Mimics ONLY the basic utilities of the pycocotools.coco.COCO class that
    are required for and used by pycocotools.cocoeval.COCOeval.

    The implementation covers the bare minimum needed to make the evaluation
    script run. An instance holds EITHER ground-truth annotations (when
    `every_prediction` is None) OR predictions, never both.
    """

    def __init__(self, dataset, every_prediction=None):
        """
        Args:
            dataset: indexable dataset; `dataset[i]` must yield a 3-tuple whose
                second item holds the annotations of image `i`. It must also
                provide `__len__`, `classid_to_name` and `get_img_info(i)`.
            every_prediction: optional iterable with one prediction container
                per image, in dataset order. When None, the wrapper is built
                from the ground truth stored in `dataset`.
        """
        # follow COCO notation: gt -> ground truth, dt -> detection
        # COCO API requires data to be held in memory throughout the eval,
        # so fingers crossed that segmentation masks converted to RLE can fit
        self.dataset = dataset
        self.every_prediction = every_prediction
        if every_prediction is None:
            # This COCOWrapper instance will hold only GT annotations
            self.gt = self._buildCOCOAnnotations()
            self.dt = None
        else:
            # This COCOWrapper instance will hold only predictions
            self.gt = None
            self.dt = self._buildCOCOPredictions()

    def getAnnIds(self, *args, **kwargs):
        """Return None: annotation ids are not needed, loadAnns ignores them."""
        return None

    def getCatIds(self, *args, **kwargs):
        """Return the category ids obtained by iterating `classid_to_name`."""
        return list(self.dataset.classid_to_name)

    def getImgIds(self, *args, **kwargs):
        """Return the image ids, which are simply the indices [0..N)."""
        return list(range(len(self.dataset)))

    def _instanceToCOCO(self, inst, ann_id, image_id, with_score=False):
        """Convert a single-instance slice into a COCO-style annotation dict.

        Args:
            inst: one-element slice of an annotation/prediction container.
            ann_id: id stored under "id"; must be non-zero (see builders).
            image_id: index of the image the instance belongs to.
            with_score: also store the "scores" field under "score".
        """
        # NOTE(review): COCOeval reads "bbox" as [x, y, w, h]; confirm that
        # inst.bbox is already in that layout.
        record = {
            "id": ann_id,
            "image_id": image_id,
            "size": inst.size,
            "bbox": inst.bbox[0].tolist(),
            "area": inst.area().item(),
            "category_id": inst.get_field("labels").item(),
            "segmentation": inst.get_field("masks"),
            "iscrowd": 0,
        }
        if with_score:
            record["score"] = inst.get_field("scores").item()
        # Replace the raw mask object with its RLE encoding.
        record["segmentation"] = self.annToRLE(record)
        return record

    def _buildCOCOAnnotations(self):
        """Build the list of ground-truth annotation dicts for the dataset."""
        print("Building COCO GT annots", flush=True)
        desc = "Parsing images"
        coco_anns = []
        # COCOeval records the id of the matched annotation and later treats
        # that value as a boolean "was matched" flag, so ids must be non-zero
        # and should be unique over the whole dataset.
        next_id = 1
        for image_id in tqdm(range(len(self.dataset)), desc=desc):
            _, anns, _ = self.dataset[image_id]
            for inst_idx in range(len(anns)):
                # TODO: find out why BoxList indexing would be a problem.
                # Only ranges can be applied ATM.
                inst = anns[inst_idx:inst_idx + 1]
                coco_anns.append(
                    self._instanceToCOCO(inst, next_id, image_id)
                )
                next_id += 1
        return coco_anns

    def _buildCOCOPredictions(self):
        """Build the list of prediction dicts from `self.every_prediction`."""
        print("Building COCO Predictions", flush=True)
        desc = "Parsing images"
        coco_preds = []
        # Same id policy as for the ground truth: unique and starting at 1.
        next_id = 1
        for image_id, predictions in tqdm(enumerate(self.every_prediction), desc=desc):
            if len(predictions) == 0:
                continue
            img_info = self.dataset.get_img_info(image_id)
            width = img_info["width"]
            height = img_info["height"]
            # Bring the predictions back to the original image resolution.
            if predictions.size[0] != width or predictions.size[1] != height:
                predictions = predictions.resize(size=(width, height))
            for inst_idx in range(len(predictions)):
                inst = predictions[inst_idx:inst_idx + 1]
                coco_preds.append(
                    # preds differ from GT only by carrying a "score"
                    self._instanceToCOCO(inst, next_id, image_id, with_score=True)
                )
                next_id += 1
        return coco_preds

    def loadAnns(self, *args, **kwargs):
        """Return every stored record (GT or predictions); arguments are ignored."""
        if self.dt is None:
            return self.gt
        return self.dt

    def annToRLE(self, ann):
        """Return the RLE encoding of `ann["segmentation"]`.

        Args:
            ann: dict with a "segmentation" entry (an RLE dict or a
                SegmentationMask) and a "size" entry.

        Raises:
            RuntimeError: if the segmentation format is not recognised.
        """
        segm = ann["segmentation"]
        # "size" is (width, height): _buildCOCOPredictions compares size[0]
        # against the image width and size[1] against the image height.
        w, h = ann["size"]
        if isinstance(segm, dict) and "counts" in segm:
            # already rle
            return segm
        if isinstance(segm, SegmentationMask) and segm.mode == "poly":
            # `ann` is a plain dict, the polygons live on the mask object.
            # NOTE(review): confirm that `instances.polygons` is already in the
            # list-of-flat-coordinate-lists form frPyObjects expects.
            polygons = segm.instances.polygons
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = mask_utils.frPyObjects(polygons, h, w)
            return mask_utils.merge(rles)
        if isinstance(segm, SegmentationMask) and segm.mode == "mask":
            # encode() wants a Fortran-ordered array with a trailing axis
            np_mask = np.array(segm.instances.masks[0, :, :, None], order="F")
            return mask_utils.encode(np_mask)[0]
        # 1-tuple keeps the formatting safe when segm itself is a tuple
        raise RuntimeError("Unknown segmentation format: %s" % (segm,))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
^ Thanks! I just wasted a few hours because my annotation IDs started at zero and not one...