Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Last active July 20, 2022 12:55
Show Gist options
  • Save AlessandroMondin/bb6e9fa98bddef40c7cc394b52104ec2 to your computer and use it in GitHub Desktop.
Save AlessandroMondin/bb6e9fa98bddef40c7cc394b52104ec2 to your computer and use it in GitHub Desktop.
class DAVIS2017(Dataset):
    """DAVIS 2017 dataset constructed using the PyTorch built-in functionalities.

    Every item is an ``(image, mask)`` pair read from the ``JPEGImages/480p``
    and ``Annotations/480p`` subfolders of ``db_root_dir``. The mask is
    binarized: every non-zero annotation pixel becomes ``1.0``.
    """

    def __init__(self, train=True,
                 db_root_dir=ROOT_DIR,
                 transform=None,
                 seq_name=None,
                 pad_mirroring=None):
        """Collect the relative paths of all image/annotation pairs.

        Parameters:
            train (bool): if True the sequences listed in
                ``ImageSets/2017/train.txt`` are used, otherwise those in
                ``ImageSets/2017/val.txt``. Together with ``seq_name`` it
                restricts the dataset to the first frame of that sequence.
            db_root_dir (path): path to the DAVIS2017 dataset; must contain
                the subfolders "JPEGImages", "Annotations" and (when
                ``seq_name`` is None) "ImageSets".
            transform: callable invoked as ``transform(image=..., mask=...)``
                returning a mapping with the keys "image" and "mask"
                (e.g. an Albumentations ``A.Compose`` pipeline).
            seq_name (str): name of a single sequence, e.g. "bear". When
                given, only frames of that sequence are used and the
                ImageSets split file is not read.
            pad_mirroring: padding forwarded to ``Pad(padding=...,
                padding_mode="reflect")`` and applied to the image only.
                Falsy values disable the padding.
        """
        self.train = train
        self.db_root_dir = db_root_dir
        self.transform = transform
        self.seq_name = seq_name
        self.pad_mirroring = pad_mirroring

        fname = 'train' if self.train else 'val'

        if self.seq_name is None:
            with open(os.path.join(db_root_dir, "ImageSets/2017", fname + '.txt')) as f:
                # Skip blank lines: an empty sequence name would make
                # os.listdir() enumerate the 480p root directory itself.
                seqs = [line.strip() for line in f if line.strip()]
            img_list = []
            labels = []
            for seq in seqs:
                img_list.extend(self._list_frames('JPEGImages/480p/', seq))
                labels.extend(self._list_frames('Annotations/480p/', seq))
        else:
            # Images and annotations are both taken from the 480p folders so
            # that every frame is paired with a mask of the same resolution.
            img_list = self._list_frames('JPEGImages/480p/', str(seq_name))
            labels = self._list_frames('Annotations/480p/', str(seq_name))
            if self.train:
                # Single-sequence training only uses the first annotated frame.
                img_list = [img_list[0]]
                labels = [labels[0]]

        assert len(labels) == len(img_list), (
            'Found %d annotations for %d images' % (len(labels), len(img_list)))
        self.img_list = img_list
        self.labels = labels
        print('Done initializing ' + fname + ' Dataset')

    def _list_frames(self, subdir, seq):
        """Return the sorted paths (relative to db_root_dir) of the files in subdir/seq."""
        names = sorted(os.listdir(os.path.join(self.db_root_dir, subdir, seq)))
        return [os.path.join(subdir, seq, name) for name in names]

    def __len__(self):
        """Return the number of image/annotation pairs."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load, binarize and optionally augment the pair at position idx.

        Returns:
            tuple: ``(img, gt)``. Without a transform, ``img`` is a float32
            RGB array and ``gt`` a float32 array holding only 0.0 and 1.0.
        """
        # Context managers close the underlying files once the pixel data
        # has been copied into the arrays.
        with Image.open(os.path.join(self.db_root_dir, self.img_list[idx])) as im:
            img = np.array(im.convert("RGB"), dtype=np.float32)
        with Image.open(os.path.join(self.db_root_dir, self.labels[idx])) as im:
            gt = np.array(im.convert("L"), dtype=np.float32)

        # Any non-zero pixel counts as foreground. The builtin bool is used
        # because the np.bool alias was removed in NumPy 1.24.
        gt = gt.astype(bool).astype(np.float32)

        if self.transform is not None:
            augmentations = self.transform(image=img, mask=gt)
            img = augmentations["image"]
            gt = augmentations["mask"]

        if self.pad_mirroring:
            # Only the image is padded; the mask keeps its size.
            # NOTE(review): Pad is presumably torchvision.transforms.Pad, which
            # expects a tensor or PIL image, i.e. a transform that converts the
            # numpy array first - confirm against the callers.
            img = Pad(padding=self.pad_mirroring, padding_mode="reflect")(img)

        return img, gt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment