Created
October 28, 2023 09:53
-
-
Save Bilaerl/e430ba6c551d6ecdf5dff39dbfabff75 to your computer and use it in GitHub Desktop.
Code Sample to share with Stickermule
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
import torch
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from torchvision.io import read_image
class EwdismDataset(Dataset):
    """ Ewdism Dataset

    Returns each example in the order img, target; where target is a dict of
    boxes, masks, and labels if include_masks is True and a dict of boxes and
    labels if include_masks is False.

    Only images that have at least one annotation in the annotation file are
    part of the dataset; images without any annotation are skipped.

    Args:
        annotation_file_path (str): path to annotation file to use. Annotation
            file should be in the COCO format
        images_folder_path (str or os.PathLike): path to folder where images
            whose annotations are in the annotation file are stored. It is
            joined with each image's ``file_name`` using ``os.path.join``, so a
            trailing path separator is optional
        include_masks (bool): whether to include masks in targets. Default is True
    """

    def __init__(self, annotation_file_path, images_folder_path, include_masks=True):
        super().__init__()
        self.images_folder_path = images_folder_path
        self.include_masks = include_masks

        coco = COCO(annotation_file_path)
        images = coco.loadImgs(coco.getImgIds())
        categories = coco.loadCats(coco.getCatIds())
        anns = coco.loadAnns(coco.getAnnIds())

        # Category id -> category name; not used internally, exposed for callers.
        self.id_to_labels = {cat['id']: cat['name'] for cat in categories}
        id_to_image_name = {img['id']: img['file_name'] for img in images}

        # Group annotations by image id. A record is created the first time one
        # of an image's annotations is seen, so unannotated images never get one.
        imgs_with_annotation = {}
        for annotation in anns:
            img_id = annotation['image_id']
            record = imgs_with_annotation.get(img_id)
            if record is None:
                record = {
                    'image_name': id_to_image_name[img_id],
                    'boxes': [],
                    'labels': [],
                }
                if self.include_masks:
                    record['masks'] = []
                imgs_with_annotation[img_id] = record

            record['boxes'].append(self.reformat_bbox(annotation['bbox']))
            record['labels'].append(annotation['category_id'])
            if self.include_masks:
                # NOTE(review): masks are decoded eagerly for the whole dataset
                # and all kept in memory (one array per annotation); decoding
                # lazily in __getitem__ may be needed for large datasets.
                record['masks'].append(coco.annToMask(annotation))

        self.data = list(imgs_with_annotation.values())
        self.len = len(self.data)

    def __getitem__(self, idx):
        """ Returns example at given index from the dataset

        Args:
            idx (int) --> index of example to fetch

        Returns:
            img (tensor) --> image at index, as returned by torchvision's read_image
            targets (dict) --> dict of boxes (float32), masks, and labels (int64)
                in image annotations if include_masks is True, or a dict of
                boxes and labels if include_masks is False
        """
        data = self.data[idx]
        img = read_image(self._image_path(data['image_name']))
        return img, self._build_target(data)

    def __len__(self):
        """ Returns number of examples in dataset

        Args:
            None

        Returns:
            self.len (int) --> Number of examples in dataset
        """
        return self.len

    def _image_path(self, image_name):
        """ Build the on-disk path of an image from the folder and its file name """
        # os.path.join works with or without a trailing separator on the folder
        # and accepts os.PathLike folders, unlike plain string concatenation.
        return os.path.join(self.images_folder_path, image_name)

    def _build_target(self, data):
        """ Convert a stored per-image record into the target dict of tensors """
        # torchvision detection models require floating-point boxes; without an
        # explicit dtype, integer COCO coordinates would produce an int64 tensor.
        boxes = torch.tensor(data['boxes'], dtype=torch.float32)
        labels = torch.tensor(data['labels'], dtype=torch.int64)
        targets = {'boxes': boxes, 'labels': labels}
        if self.include_masks:
            # Stacking into one ndarray first: converting a list of arrays
            # directly to torch.Tensor is slow.
            targets['masks'] = torch.from_numpy(np.stack(data['masks']))
        return targets

    def reformat_bbox(self, bbox):
        """ Reformat COCO's bbox into one compatible with Pytorch

        COCO's bbox is in the [xmin, ymin, width, height] format,
        while Pytorch requires a [xmin, ymin, xmax, ymax] format.

        Args:
            bbox (list): --> bbox in COCO format

        Returns:
            bbox (list): --> bbox in Pytorch format
        """
        xmin, ymin, width, height = bbox
        return [xmin, ymin, xmin + width, ymin + height]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment