Last active: July 20, 2017 07:05
Save harveyslash/8e8256bd1e260c239c4e6c5460e8641c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SiameseNetworkDataset(Dataset):
    """Dataset yielding random image pairs for training a Siamese network.

    Each item is ``(img0, img1, label)``. ``label`` is a float32 tensor of
    shape ``(1,)``: ``0.0`` when both images share a class and ``1.0`` when
    they do not. Pairs are drawn at random on every access. The ``index``
    passed to ``__getitem__`` is ignored, so ``__len__`` only controls how many
    pairs make up an epoch.

    Args:
        imageFolderDataset: Object exposing an ``imgs`` sequence of tuples
            whose item 0 is an image path and item 1 is a class label
            (presumably a torchvision ``ImageFolder``; confirm with caller).
        transform: Optional callable applied to each image after grayscale
            conversion and optional inversion.
        should_invert: If True, invert both images with
            ``PIL.ImageOps.invert``.
    """

    def __init__(self, imageFolderDataset, transform=None, should_invert=True):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert
        # Number of distinct classes. It tells us whether a "different class"
        # partner can exist at all, which avoids an endless search on a
        # single-class dataset. It is computed once, so this assumes `imgs`
        # is not mutated afterwards.
        self._num_classes = len({item[1] for item in imageFolderDataset.imgs})

    def __getitem__(self, index):
        """Return a random ``(img0, img1, label)`` pair; *index* is unused."""
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        # Fair coin: about half of the pairs are same-class and half are
        # different-class (when more than one class exists).
        should_get_same_class = random.randint(0, 1)
        img1_tuple = self._pick_partner(img0_tuple, should_get_same_class)

        img0 = self._load(img0_tuple[0])
        img1 = self._load(img1_tuple[0])

        # 0.0 = same class, 1.0 = different class.
        label = float(img1_tuple[1] != img0_tuple[1])
        return img0, img1, torch.from_numpy(np.array([label], dtype=np.float32))

    def _pick_partner(self, img0_tuple, same_class):
        """Return a random ``imgs`` entry to pair with *img0_tuple*.

        If *same_class* is truthy, the partner shares ``img0_tuple``'s class.
        It may be the very same entry. Otherwise the partner is drawn from a
        different class whenever the dataset has at least two classes. With a
        single class, any entry is returned.
        """
        imgs = self.imageFolderDataset.imgs
        label0 = img0_tuple[1]
        want_same = bool(same_class)
        if not want_same and self._num_classes < 2:
            # No other class exists, so fall back to an unconstrained draw.
            return random.choice(imgs)
        # Rejection sampling, uniform over the qualifying entries. It
        # terminates in both cases:
        # - same-class: img0_tuple (drawn from imgs) itself qualifies;
        # - different-class: another class is known to exist.
        while True:
            candidate = random.choice(imgs)
            if (candidate[1] == label0) == want_same:
                return candidate

    def _load(self, path):
        """Open *path* as grayscale, optionally invert it, then apply transform."""
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(path) as raw:
            img = raw.convert("L")
        if self.should_invert:
            img = PIL.ImageOps.invert(img)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        """Number of pairs per epoch: the size of the underlying image list."""
        return len(self.imageFolderDataset.imgs)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.