Last active
July 20, 2022 12:55
-
-
Save AlessandroMondin/bb6e9fa98bddef40c7cc394b52104ec2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DAVIS2017(Dataset):
    """DAVIS 2017 dataset constructed using the PyTorch built-in functionalities.

    Each item is an (image, binary mask) pair read from the 480p folders of the
    dataset.
    """

    def __init__(self, train=True,
                 db_root_dir=ROOT_DIR,
                 transform=None,
                 seq_name=None,
                 pad_mirroring=None):
        """Build the lists of image / annotation paths (relative to db_root_dir).

        db_root_dir is expected to contain the subfolders "ImageSets/2017",
        "JPEGImages/480p" and "Annotations/480p".

        Parameters:
            train (bool): if True, sequences are read from
                "ImageSets/2017/train.txt", otherwise from "ImageSets/2017/val.txt".
                When ``seq_name`` is given and train is True, only the first
                frame of that sequence (and its annotation) is kept.
            db_root_dir (path): path to the DAVIS2017 dataset root.
            transform: optional Albumentations pipeline (e.g. ``A.Compose``),
                called as ``transform(image=..., mask=...)``.
            seq_name (str): name of a single sequence, e.g. "bear". When given,
                only the frames of that sequence are used instead of a split file.
            pad_mirroring: if truthy, passed as ``padding`` to torchvision's
                ``Pad`` with ``padding_mode="reflect"``; applied to the image only.
        """
        self.train = train
        self.db_root_dir = db_root_dir
        self.transform = transform
        self.seq_name = seq_name
        self.pad_mirroring = pad_mirroring

        fname = 'train' if self.train else 'val'

        if self.seq_name is None:
            with open(os.path.join(db_root_dir, "ImageSets/2017", fname + '.txt')) as f:
                # Skip blank lines (e.g. a trailing newline) so they are not
                # treated as a sequence named "" (which would list the parent dir).
                seqs = [line.strip() for line in f if line.strip()]
            img_list = []
            labels = []
            for seq in seqs:
                img_list.extend(self._sorted_paths('JPEGImages/480p/', seq))
                labels.extend(self._sorted_paths('Annotations/480p/', seq))
        else:
            seq = str(seq_name)
            # Images are taken from the same 480p folder they are listed from,
            # so that image and mask resolutions match.
            img_list = self._sorted_paths('JPEGImages/480p/', seq)
            # Collect every annotation of the sequence; keeping only the first
            # one would make the length check below fail for train=False.
            labels = self._sorted_paths('Annotations/480p/', seq)
            if self.train:
                img_list = [img_list[0]]
                labels = [labels[0]]

        assert len(labels) == len(img_list), (
            f"found {len(img_list)} images but {len(labels)} annotations")

        self.img_list = img_list
        self.labels = labels

        print('Done initializing ' + fname + ' Dataset')

    def _sorted_paths(self, subdir, seq):
        """Return the sorted file paths (relative to db_root_dir) inside subdir/seq."""
        names = sorted(os.listdir(os.path.join(self.db_root_dir, subdir, seq)))
        return [os.path.join(subdir, seq, name) for name in names]

    def __len__(self):
        """Return the number of (image, annotation) pairs."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load the idx-th pair and return ``(img, gt)``.

        Before any transform, img is a float32 RGB array and gt a float32 mask
        with 1.0 wherever the annotation is non-zero and 0.0 elsewhere. The
        output types after ``transform`` / ``pad_mirroring`` depend on the
        supplied transform.
        """
        img_path = os.path.join(self.db_root_dir, self.img_list[idx])
        gt_path = os.path.join(self.db_root_dir, self.labels[idx])

        # Context managers close the underlying file handles promptly.
        with Image.open(img_path) as img_file:
            img = np.array(img_file.convert("RGB"), dtype=np.float32)
        with Image.open(gt_path) as gt_file:
            gt = np.array(gt_file.convert("L"), dtype=np.float32)

        # Collapse all object ids into a single foreground mask. Use the builtin
        # bool: the np.bool alias was removed in NumPy 1.24.
        gt = gt.astype(bool).astype(np.float32)

        if self.transform is not None:
            augmentations = self.transform(image=img, mask=gt)
            img = augmentations["image"]
            gt = augmentations["mask"]

        if self.pad_mirroring:
            # Only the image is padded; the mask keeps its original size.
            img = Pad(padding=self.pad_mirroring, padding_mode="reflect")(img)

        return img, gt
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment