Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active June 30, 2021 10:31
Show Gist options
  • Select an option

  • Save sadimanna/c247acde2edbdd744182b0789acd31d6 to your computer and use it in GitHub Desktop.

Select an option

Save sadimanna/c247acde2edbdd744182b0789acd31d6 to your computer and use it in GitHub Desktop.
class C10DataGen(Dataset):
    """Dataset yielding two independently augmented, normalized views per image.

    Each ``__getitem__`` call scales the stored image to ``[0, 1]`` (divides by
    255), passes it through :meth:`augment` twice and normalizes both results
    with :meth:`preprocess`, returning the pair ``(x1, x2)``. The class name
    suggests CIFAR-10 images (32x32, matching the default crop size) -- TODO
    confirm against the caller.

    Args:
        phase: Split name. Random augmentation is applied only when this is
            ``'train'``; for any other value images pass through un-augmented.
        imgarr: Indexable image array; ``imgarr.shape[0]`` is the number of
            samples and items must support ``.astype`` (presumably a NumPy
            ``uint8`` array -- TODO confirm).
        s: Color-jitter strength. Brightness, contrast and saturation jitter
            are ``0.8 * s``; hue jitter is ``0.2 * s``.
        crop_size: Output side length of ``RandomResizedCrop``. Defaults to 32,
            the previously hard-coded value.
    """

    def __init__(self, phase, imgarr, s=0.5, crop_size=32):
        self.phase = phase
        self.imgarr = imgarr
        self.s = s
        # One flat pipeline, applied in this order on every call.
        self.transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomResizedCrop(crop_size, (0.8, 1.0)),
            # Color jitter is applied with probability 0.8.
            transforms.RandomApply(
                [transforms.ColorJitter(0.8 * self.s,
                                        0.8 * self.s,
                                        0.8 * self.s,
                                        0.2 * self.s)],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
        ])

    def __len__(self):
        """Return the number of images (first dimension of ``imgarr``)."""
        return self.imgarr.shape[0]

    def __getitem__(self, idx):
        """Return two augmented, normalized views ``(x1, x2)`` of image ``idx``."""
        # NOTE(review): torchvision tensor transforms treat the last two
        # dimensions as (H, W), i.e. they expect channel-first input. If
        # imgarr stores HWC images, transpose before augmenting -- TODO confirm.
        x = self.imgarr[idx].astype(np.float32) / 255.0
        # Separate augment calls so that, in the 'train' phase, each view
        # gets its own random draw of the pipeline.
        x1 = self.augment(torch.from_numpy(x))
        x2 = self.augment(torch.from_numpy(x))
        x1 = self.preprocess(x1)
        x2 = self.preprocess(x2)
        return x1, x2

    def on_epoch_end(self):
        """Shuffle the stored images by rebinding ``imgarr`` to a permuted copy.

        The permutation comes from the global :mod:`random` generator, so
        ``random.seed`` controls the resulting order.
        """
        # NOTE(review): nothing in this class invokes this hook and the torch
        # DataLoader does not call it; the training loop must call it
        # explicitly -- confirm.
        n = len(self)
        self.imgarr = self.imgarr[random.sample(range(n), k=n)]

    def preprocess(self, frame):
        """Normalize ``frame`` as ``(frame - MEAN) / STD``.

        ``MEAN`` and ``STD`` are module-level names defined outside this class
        (presumably per-channel statistics that broadcast against ``frame`` --
        confirm).
        """
        return (frame - MEAN) / STD

    def augment(self, frame, transformations=None):
        """Apply the random augmentation pipeline in the ``'train'`` phase.

        Args:
            frame: Image tensor to augment.
            transformations: Unused; kept only for backward compatibility.

        Returns:
            ``self.transforms(frame)`` when ``self.phase == 'train'``;
            otherwise ``frame`` itself, unchanged.
        """
        if self.phase != 'train':
            return frame
        return self.transforms(frame)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment