Skip to content

Instantly share code, notes, and snippets.

@Bilaerl
Created October 28, 2023 09:53
Show Gist options
  • Save Bilaerl/e430ba6c551d6ecdf5dff39dbfabff75 to your computer and use it in GitHub Desktop.
Save Bilaerl/e430ba6c551d6ecdf5dff39dbfabff75 to your computer and use it in GitHub Desktop.
Code Sample to share with Stickermule
import os

import numpy as np
import torch
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from torchvision.io import read_image
class EwdismDataset(Dataset):
    """Ewdism dataset backed by a COCO-format annotation file.

    Each example is returned as ``(img, targets)``, where ``targets`` is a
    dict holding ``boxes`` and ``labels`` and, when ``include_masks`` is True,
    also ``masks``.

    Only images that are referenced by at least one annotation become
    examples; images listed in the annotation file without annotations are
    not part of the dataset.

    Args:
        annotation_file_path (str): path to the annotation file to use. The
            file should be in the COCO format.
        images_folder_path (str): path to the folder where the images
            referenced by the annotation file are stored. A trailing path
            separator is optional.
        include_masks (bool): whether to include masks in targets.
            Default is True.

    Attributes:
        id_to_labels (dict): maps category id to category name.
        data (list): one dict per annotated image with the keys
            ``image_name``, ``boxes``, ``labels`` and (if ``include_masks``)
            ``masks``.
        len (int): number of examples in the dataset.
    """

    def __init__(self, annotation_file_path, images_folder_path, include_masks=True):
        super().__init__()
        self.images_folder_path = images_folder_path
        self.include_masks = include_masks

        annotations = COCO(annotation_file_path)
        images = annotations.loadImgs(annotations.getImgIds())
        labels = annotations.loadCats(annotations.getCatIds())
        anns = annotations.loadAnns(annotations.getAnnIds())

        self.id_to_labels = {label['id']: label['name'] for label in labels}
        id_to_image_name = {img['id']: img['file_name'] for img in images}

        # Group annotations per image; dict insertion order fixes the
        # example order (order in which images are first seen in `anns`).
        imgs_with_annotation = {}
        for annotation in anns:
            img_id = annotation['image_id']
            record = imgs_with_annotation.get(img_id)
            if record is None:
                # First annotation for this image: create the whole record
                # at once instead of re-checking each key per annotation.
                record = {
                    'image_name': id_to_image_name[img_id],
                    'boxes': [],
                    'labels': [],
                }
                if self.include_masks:
                    record['masks'] = []
                imgs_with_annotation[img_id] = record

            record['boxes'].append(self.reformat_bbox(annotation['bbox']))
            record['labels'].append(annotation['category_id'])
            if self.include_masks:
                # NOTE(review): every mask is decoded and kept in memory
                # here, up front; for large datasets this may be costly.
                record['masks'].append(annotations.annToMask(annotation))

        self.data = list(imgs_with_annotation.values())
        self.len = len(self.data)

    def __getitem__(self, idx):
        """Return the example at the given index.

        Args:
            idx (int): index of the example to fetch.

        Returns:
            img (tensor): the image at ``idx`` as read by
                ``torchvision.io.read_image``.
            targets (dict): dict of boxes, labels and (if ``include_masks``
                is True) masks for the image's annotations.
        """
        record = self.data[idx]
        img = read_image(self._image_path(record['image_name']))
        return img, self._build_targets(record)

    def __len__(self):
        """Return the number of examples in the dataset."""
        return self.len

    def _image_path(self, image_name):
        """Return the path of ``image_name`` inside the images folder."""
        # os.path.join inserts the separator only when it is missing, so the
        # folder path works with or without a trailing slash.
        return os.path.join(self.images_folder_path, image_name)

    def _build_targets(self, record):
        """Convert one entry of ``self.data`` into a dict of tensors.

        Args:
            record (dict): entry of ``self.data`` with ``boxes``, ``labels``
                and, if ``include_masks`` is True, ``masks``.

        Returns:
            dict: ``boxes`` as a float32 tensor, ``labels`` as an int64
            tensor and, if ``include_masks`` is True, ``masks`` as a tensor
            stacked along a new first dimension.
        """
        targets = {
            # Explicit dtypes: annotation files that store integer box
            # coordinates would otherwise produce int64 boxes.
            'boxes': torch.tensor(record['boxes'], dtype=torch.float32),
            'labels': torch.tensor(record['labels'], dtype=torch.int64),
        }
        if self.include_masks:
            # Going through a single numpy array first; converting a list of
            # arrays directly to torch.Tensor is slow.
            masks = np.array(record['masks'])
            targets['masks'] = torch.from_numpy(masks)
        return targets

    def reformat_bbox(self, bbox):
        """Reformat a COCO bbox into the format PyTorch expects.

        COCO's bbox is in the [xmin, ymin, width, height] format, while
        PyTorch requires a [xmin, ymin, xmax, ymax] format.

        Args:
            bbox (list): bbox in COCO format.

        Returns:
            list: bbox in [xmin, ymin, xmax, ymax] format.
        """
        xmin, ymin, width, height = bbox
        return [xmin, ymin, xmin + width, ymin + height]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment