Last active
March 27, 2019 16:26
-
-
Save botcs/72a221f8a95471155b25a9e655a654e1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import glob | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| import random | |
| from maskrcnn_benchmark.structures.bounding_box import BoxList | |
| from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask | |
class DebugDataset(torch.utils.data.Dataset):
    """
    Synthetic dataset of white squares on a black background.

    Every sample is generated in memory at construction time, so the dataset
    needs no files on disk. It is meant as a minimal example / smoke test.

    A generic Dataset for the maskrcnn_benchmark must have the following
    non-trivial fields / methods implemented:

    classid_to_name - dict:
        This will allow the trivial generation of classid_to_ccid
        (contiguous) and ccid_to_classid (reversed)

    __getitem__ - function(idx):
        This has to return three things: img, target, idx.
        img is the input image, which has to be loaded as a PIL Image object.
        Implementing the target requires the most effort, since it must have
        multiple fields: the size, bounding boxes, labels (contiguous), and
        masks (either COCO-style Polygons, RLE or torch BinaryMask).
        Ideally the target is a BoxList instance with extra fields.
        Lastly, idx is simply the input argument of the function.

    Also the following will be used:

    __len__ - function():
        return the size of the dataset

    get_img_info - function(idx):
        return metadata, at least width and height of the input image
    """

    def __init__(
        self,
        dataset_size,
        width=640,
        height=480,
        side_length=50,
        min_num_instances=3,
        max_num_instances=3,
        allow_overlap=False,
        transforms=None,
    ):
        """
        Arguments:
            dataset_size: number of images to generate
            width: width of every generated image, in pixels
            height: height of every generated image, in pixels
            side_length: side of every generated square, in pixels
            min_num_instances: lower bound (inclusive) of squares per image
            max_num_instances: upper bound (inclusive) of squares per image
            allow_overlap: if False, a square that intersects an already
                placed square of the same image is rejected and re-sampled
            transforms: optional callable, applied in __getitem__ as
                ``img, target = transforms(img, target)``
        """
        self.w = width
        self.h = height
        self.d = side_length
        self.min_inst = min_num_instances
        self.max_inst = max_num_instances
        self.allow_overlap = allow_overlap
        self.transforms = transforms

        self.classid_to_name = {
            # 0: "unlabeled" - exclude background category, otherwise it will
            # influence the mAP score
            23: "box"
        }

        self.imgs, self.anns = self.generateDebugBatch(dataset_size)
        self._initCCIDMaps()

    def generateDebugBatch(self, batch_size=2):
        """
        Generate ``batch_size`` images together with their instance masks.

        This is just a specific implementation that replaces reading the input
        and the annotation images / files in. Boxes are not stored; they are
        derived later from the masks (see _mask_to_tight_box).

        Squares are only prevented from overlapping when ``allow_overlap`` is
        False. In that mode placement is done by rejection sampling with no
        retry limit, so it does not terminate if the requested squares cannot
        fit into the image without intersecting.

        Returns:
            images: uint8 tensor of shape (batch_size, height, width, 3),
                255 inside the squares and 0 elsewhere
            masks: list with one uint8 tensor per image, each of shape
                (num_instances, height, width), 255 inside the instance
        """
        w, h, d = self.w, self.h, self.d
        images = torch.zeros(batch_size, h, w, 3, dtype=torch.uint8)
        masks = []

        def is_overlapping(box1, box2):
            # Strict inequalities: squares that merely share an edge
            # coordinate do not count as overlapping.
            x1min, y1min, x1max, y1max = box1
            x2min, y2min, x2max, y2max = box2
            return (x1min < x2max and x2min < x1max and
                    y1min < y2max and y2min < y1max)

        for i in range(batch_size):
            boxes = []
            num_of_inst = random.randint(self.min_inst, self.max_inst)
            mask = torch.zeros(num_of_inst, h, w, dtype=torch.uint8)
            for n in range(num_of_inst):
                # Always draw at least one candidate. Previously the sampling
                # was skipped entirely when allow_overlap was True, leaving
                # the coordinates undefined.
                while True:
                    xmin = random.randint(0, w - d)
                    ymin = random.randint(0, h - d)
                    # perfect squares
                    xmax = xmin + d
                    ymax = ymin + d
                    box_candidate = (xmin, ymin, xmax, ymax)
                    if self.allow_overlap or not any(
                        is_overlapping(box_candidate, box) for box in boxes
                    ):
                        break
                boxes.append(box_candidate)
                images[i, ymin:ymax, xmin:xmax] = 255
                mask[n, ymin:ymax, xmin:xmax] = 255
            masks.append(mask)
        return images, masks

    def __getitem__(self, idx):
        """
        Return ``(img, target, idx)`` for the sample at position ``idx``.

        img is a PIL RGB image; target is a BoxList in "xyxy" mode carrying
        the extra fields "labels" and "masks". If transforms were given, they
        are applied to (img, target) before returning.
        """
        img = Image.fromarray(self.imgs[idx].numpy(), mode="RGB")
        masks = self.anns[idx]
        boxes = [self._mask_to_tight_box(mask) for mask in masks]
        labels = self._extract_labels(idx)

        # Compose all into a BoxList instance
        target = BoxList(boxes, img.size, mode="xyxy")
        target.add_field("labels", labels)
        masks = SegmentationMask(masks, img.size, "mask")
        target.add_field("masks", masks)

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target, idx

    def _extract_labels(self, idx):
        """
        It is the user's task to implement label extraction from the external
        data. In this function for demonstration we are just using box labels.
        There are no restrictions on arguments, and if any simplification is
        available then it should be applied here.

        Returns: a torch.long tensor holding one contiguous class id per
        instance of sample ``idx`` (every instance is a "box" here)
        """
        num_of_inst = len(self.anns[idx])
        classid = self.name_to_classid["box"]
        ccid = self.classid_to_ccid[classid]
        return torch.tensor([ccid] * num_of_inst, dtype=torch.long)

    def _mask_to_tight_box(self, mask):
        """
        Return the tightest box around the nonzero pixels of a 2D mask.

        The result is (x_min, y_min, x_max, y_max) as 0-dim tensors, where the
        max values are the indices of the last nonzero column / row, i.e. they
        are inclusive. The mask must contain at least one nonzero pixel.
        """
        a = mask.nonzero()  # rows of (y, x) indices
        ys, xs = a[:, 0], a[:, 1]
        return torch.min(xs), torch.min(ys), torch.max(xs), torch.max(ys)

    def _initCCIDMaps(self):
        """Build the lookup tables derived from ``classid_to_name``."""
        # It is important that classid should not list the background;
        # more important, that ccid must start from 1, since the background is
        # always associated to ccid=0 by the training tools.
        self.classid_to_ccid = {
            classid: ccid
            for ccid, classid in enumerate(self.classid_to_name, 1)
        }
        self.ccid_to_classid = {
            ccid: classid
            for classid, ccid in self.classid_to_ccid.items()
        }
        self.name_to_classid = {
            name: classid
            for classid, name in self.classid_to_name.items()
        }

    def __len__(self):
        """Return the number of generated images."""
        return len(self.imgs)

    def get_img_info(self, index):
        """Return the metadata (width, height, idx) of the sample ``index``."""
        return {
            "width": self.w,
            "height": self.h,
            "idx": index,
        }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment