Skip to content

Instantly share code, notes, and snippets.

@botcs
Last active March 27, 2019 16:26
Show Gist options
  • Select an option

  • Save botcs/72a221f8a95471155b25a9e655a654e1 to your computer and use it in GitHub Desktop.

Select an option

Save botcs/72a221f8a95471155b25a9e655a654e1 to your computer and use it in GitHub Desktop.
import os
import glob
from PIL import Image
import numpy as np
import torch
import torchvision
import random
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
class DebugDataset(torch.utils.data.Dataset):
    """
    Synthetic dataset of white squares on a black background, intended for
    debugging a maskrcnn_benchmark training pipeline without real data.

    A generic Dataset for the maskrcnn_benchmark must have the following
    non-trivial fields / methods implemented:

    classid_to_name - dict:
        This will allow the trivial generation of classid_to_ccid
        (contiguous) and ccid_to_classid (reversed)

    __getitem__ - function(idx):
        This has to return three things: img, target, idx.
        img is the input image, which has to be loaded as a PIL Image object.
        Implementing the target requires the most effort, since it must have
        multiple fields: the size, bounding boxes, labels (contiguous), and
        masks (either COCO-style Polygons, RLE or torch BinaryMask).
        Ideally the target is a BoxList instance with extra fields.
        Lastly, idx is simply the input argument of the function.

    Also the following will be used:

    __len__ - function():
        return the size of the dataset

    get_img_info - function(idx):
        return metadata, at least width and height of the input image
    """

    # Upper bound on random placement attempts for a single box. Without a
    # bound, a layout that cannot fit without overlap would loop forever.
    _MAX_PLACEMENT_ATTEMPTS = 10000

    def __init__(
        self,
        dataset_size,
        width=640,
        height=480,
        side_length=50,
        min_num_instances=3,
        max_num_instances=3,
        allow_overlap=False,
        transforms=None,
    ):
        """
        Arguments:
            dataset_size: number of synthetic images to generate
            width: width of every generated image, in pixels
            height: height of every generated image, in pixels
            side_length: side of the square instances, in pixels
            min_num_instances: lower bound (inclusive) on squares per image
            max_num_instances: upper bound (inclusive) on squares per image
            allow_overlap: if False, squares within one image never overlap
            transforms: optional callable applied in __getitem__ as
                ``img, target = transforms(img, target)``

        All images and masks are generated eagerly here and kept in memory.
        """
        self.w = width
        self.h = height
        self.d = side_length
        self.min_inst = min_num_instances
        self.max_inst = max_num_instances
        self.allow_overlap = allow_overlap
        self.transforms = transforms

        self.classid_to_name = {
            # 0: "unlabeled" - exclude background category, otherwise it will
            # influence the mAP score
            23: "box"
        }

        self.imgs, self.anns = self.generateDebugBatch(dataset_size)
        self._initCCIDMaps()

    @staticmethod
    def _boxes_overlap(box1, box2):
        """Return True if two (xmin, ymin, xmax, ymax) boxes share area."""
        x1min, y1min, x1max, y1max = box1
        x2min, y2min, x2max, y2max = box2
        # Strict inequalities: boxes that merely touch do not overlap.
        return (x1min < x2max and x2min < x1max and
                y1min < y2max and y2min < y1max)

    def _sample_box(self, placed_boxes):
        """
        Draw a random square that fits inside the image.

        Arguments:
            placed_boxes: boxes already placed in the current image

        Returns: (xmin, ymin, xmax, ymax) with xmax/ymax exclusive.

        Raises:
            ValueError: from random.randint when the square is larger than
                the image.
            RuntimeError: when overlap is forbidden and no free position was
                found within _MAX_PLACEMENT_ATTEMPTS tries.
        """
        w, h, d = self.w, self.h, self.d
        for _ in range(self._MAX_PLACEMENT_ATTEMPTS):
            xmin = random.randint(0, w - d)
            ymin = random.randint(0, h - d)
            # perfect squares
            candidate = (xmin, ymin, xmin + d, ymin + d)
            # A candidate is always drawn, even when overlap is allowed, so
            # the caller never sees unset coordinates.
            if self.allow_overlap or not any(
                    self._boxes_overlap(candidate, box)
                    for box in placed_boxes):
                return candidate
        raise RuntimeError(
            f"could not place a non-overlapping {d}x{d} box in a {w}x{h} "
            f"image after {self._MAX_PLACEMENT_ATTEMPTS} attempts"
        )

    def generateDebugBatch(self, batch_size=2):
        """
        This is just a specific implementation that replaces reading the input
        and the annotation images / files in.

        Every instance is a filled square; the boxes are later derived from
        the masks (see _mask_to_tight_box) so that they fit them tightly.
        Squares overlap only if ``allow_overlap`` was set.

        Returns:
            images: uint8 tensor of shape (batch_size, h, w, 3), squares
                drawn as 255 on a zero background
            masks: list with one uint8 tensor of shape (num_instances, h, w)
                per image, 255 inside the instance and 0 elsewhere
        """
        w, h = self.w, self.h
        images = torch.zeros(batch_size, h, w, 3, dtype=torch.uint8)
        masks = []

        for i in range(batch_size):
            boxes = []
            num_of_inst = random.randint(self.min_inst, self.max_inst)
            mask = torch.zeros(num_of_inst, h, w, dtype=torch.uint8)
            for n in range(num_of_inst):
                box = self._sample_box(boxes)
                boxes.append(box)
                xmin, ymin, xmax, ymax = box
                images[i, ymin:ymax, xmin:xmax] = 255
                mask[n, ymin:ymax, xmin:xmax] = 255
            masks.append(mask)

        return images, masks

    def __getitem__(self, idx):
        """Return (img, target, idx) for sample ``idx``.

        img is a PIL RGB image; target is a BoxList carrying the "labels"
        and "masks" fields. If transforms were given they are applied to
        both before returning.
        """
        img = Image.fromarray(self.imgs[idx].numpy(), mode="RGB")
        masks = self.anns[idx]
        boxes = [self._mask_to_tight_box(mask) for mask in masks]
        labels = self._extract_labels(idx)

        # Compose all into a BoxList instance
        target = BoxList(boxes, img.size, mode="xyxy")
        target.add_field("labels", labels)
        masks = SegmentationMask(masks, img.size, "mask")
        target.add_field("masks", masks)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target, idx

    def _extract_labels(self, idx):
        """
        It is the user's task to implement label extraction from the external
        data. In this function for demonstration we are just using box labels.
        There are no restrictions on arguments, and if any simplification is
        available then it should be applied here.

        Returns: a torch.LongTensor holding one contiguous class id per
        instance of sample ``idx``.
        """
        num_of_inst = len(self.anns[idx])
        classid = self.name_to_classid["box"]
        ccid = self.classid_to_ccid[classid]
        return torch.full((num_of_inst,), ccid, dtype=torch.long)

    def _mask_to_tight_box(self, mask):
        """
        Return the tightest box around the nonzero pixels of a 2D mask.

        Returns: (x_min, y_min, x_max, y_max) as 0-dim tensors, where the
        max coordinates are the indices of the last nonzero column / row
        (i.e. inclusive). The mask must contain at least one nonzero pixel.
        """
        # nonzero() yields (row, col) pairs: column 0 is y, column 1 is x.
        coords = mask.nonzero()
        ys, xs = coords[:, 0], coords[:, 1]
        return torch.min(xs), torch.min(ys), torch.max(xs), torch.max(ys)

    def _initCCIDMaps(self):
        """Build the lookup tables between class ids, contiguous ids, names."""
        # It is important that classid should not list the background;
        # more important, that ccid must start from 1, since the background is
        # always associated to ccid=0 by the training tools.
        self.classid_to_ccid = {
            classid: ccid
            for ccid, classid in enumerate(self.classid_to_name, 1)
        }
        self.ccid_to_classid = {
            ccid: classid
            for classid, ccid in self.classid_to_ccid.items()
        }
        self.name_to_classid = {
            name: classid
            for classid, name in self.classid_to_name.items()
        }

    def __len__(self):
        """Return the number of generated images."""
        return len(self.imgs)

    def get_img_info(self, index):
        """Return the metadata (width, height, idx) of sample ``index``."""
        return {
            "width": self.w,
            "height": self.h,
            "idx": index,
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment