Skip to content

Instantly share code, notes, and snippets.

@level14taken
Last active December 29, 2020 11:00
Show Gist options
  • Save level14taken/0044b1dfcf96e77df13cc557f494b9a2 to your computer and use it in GitHub Desktop.
Save level14taken/0044b1dfcf96e77df13cc557f494b9a2 to your computer and use it in GitHub Desktop.
class TGSSaltDataset(data.Dataset):
    """Map-style dataset over TGS salt images and (optionally) their masks.

    For each id in ``file_list`` the image is read from
    ``<root_path>/images/<id>.png`` and, unless ``is_test`` is set, the mask
    from ``<root_path>/masks/<id>.png``. Images are converted to RGB,
    median-shifted, passed through ``augment`` and returned as float tensors.

    Args:
        root_path: Directory that contains the ``images`` and ``masks`` folders.
        file_list: Indexable sequence of file ids (file names without ``.png``).
        is_test: When True, only images are loaded and ``__getitem__`` returns
            a single image tensor.
        augment: Callable invoked with keyword arguments ``image=`` (and
            ``mask=`` outside test mode) that returns a dict holding the same
            keys -- presumably an albumentations-style transform; confirm
            against ``def_transforms``.
        dpsv: Outside test mode, return the image followed by the values
            produced by ``get_others`` on the augmented mask instead of the
            mask itself (presumably "deep supervision" targets -- TODO confirm).
    """

    def __init__(self, root_path, file_list, is_test=False,
                 augment=def_transforms, dpsv=False):
        self.is_test = is_test
        self.root_path = root_path
        self.file_list = file_list
        self.augment = augment
        self.dpsv = dpsv

    def __len__(self):
        """Return the number of file ids in the dataset."""
        return len(self.file_list)

    @staticmethod
    def _to_float_tensor(array):
        """Copy ``array`` into a new float32 tensor.

        ``np.ascontiguousarray`` casts any integer dtype (e.g. a 16-bit mask)
        to float32 and removes negative strides left by flip-style
        augmentations; depending on the torch version, the legacy
        ``torch.FloatTensor(ndarray)`` constructor rejects both.
        """
        return torch.tensor(np.ascontiguousarray(array, dtype=np.float32))

    def __getitem__(self, index):
        """Load, normalise and augment the sample at position ``index``.

        Returns:
            * test mode: the image tensor, made channels-first with
              ``permute(2, 0, 1)``;
            * ``dpsv`` mode: ``(image, *get_others(mask))``;
            * otherwise: ``(image, mask)``, where the mask gains a leading
              singleton dimension via ``unsqueeze(0)``.
        """
        file_id = self.file_list[index]
        image_path = os.path.join(self.root_path, "images", file_id + ".png")

        # The context manager guarantees the file handle is released;
        # convert() forces a full load, so the RGB copy outlives the file.
        with PIL.Image.open(image_path) as raw_image:
            image = np.array(raw_image.convert("RGB")).astype(np.float32)
        # Shift intensities so the image median sits at 127, then clip back to
        # [0, 255]; the original intent is to remove noise inherent in the data.
        image = np.clip(image - np.median(image) + 127, 0, 255)

        if self.is_test:
            augmented = self.augment(image=image)
            return self._to_float_tensor(augmented["image"]).permute(2, 0, 1)

        mask_path = os.path.join(self.root_path, "masks", file_id + ".png")
        mask = np.array(imread(mask_path))
        augmented = self.augment(image=image, mask=mask)
        image_tensor = self._to_float_tensor(augmented["image"]).permute(2, 0, 1)

        if self.dpsv:
            # Extra targets are derived from the augmented mask by get_others.
            others = get_others(augmented["mask"])
            return (image_tensor, *others)

        mask_tensor = self._to_float_tensor(augmented["mask"]).unsqueeze(0)
        return (image_tensor, mask_tensor)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment