Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active June 30, 2021 11:14
Show Gist options
  • Save sadimanna/07729a36aca588ccf53a15f4723dee8a to your computer and use it in GitHub Desktop.
Save sadimanna/07729a36aca588ccf53a15f4723dee8a to your computer and use it in GitHub Desktop.
class DSDataGen(Dataset):
    """In-memory map-style dataset yielding ``(image_tensor, label)`` pairs.

    Each item is read from ``imgarr`` and converted to a float tensor. When
    ``phase == 'train'`` it is randomly crop-resized, and it is always
    normalised by :meth:`preprocess`.

    Args:
        phase: Split name. Only the exact value ``'train'`` enables augmentation.
        imgarr: Images indexed along axis 0. It must expose ``.shape`` and
            accept a list of ints as an index (see ``on_epoch_end``), and each
            element must be accepted by ``torch.from_numpy``, so presumably a
            NumPy array.
            NOTE(review): a tensor passed to ``RandomResizedCrop`` is cropped
            over its last two dimensions, so images are presumably stored
            channel-first ``(C, H, W)``. Confirm against the caller.
        labels: Per-image labels aligned with ``imgarr`` along axis 0. They
            must also accept a list of ints as an index (presumably a NumPy
            array).
        num_classes: Stored as ``self.num_classes``; not read by this class.
    """

    def __init__(self, phase, imgarr, labels, num_classes):
        self.imgarr = imgarr
        self.labels = labels
        self.phase = phase
        self.num_classes = num_classes
        # Training augmentation: crop 80-100% of the image area, resize to 32.
        self.randomcrop = transforms.RandomResizedCrop(32, (0.8, 1.0))

    def __len__(self):
        """Return the number of images (size of ``imgarr`` along axis 0)."""
        return self.imgarr.shape[0]

    def __getitem__(self, idx):
        """Return the normalised image tensor and its label at position ``idx``."""
        tensor = torch.from_numpy(self.imgarr[idx]).float()
        target = self.labels[idx]
        if self.phase == 'train':
            # Only the training split is augmented; other phases keep the raw image.
            tensor = self.randomcrop(tensor)
        return self.preprocess(tensor), target

    def on_epoch_end(self):
        """Reorder ``imgarr`` and ``labels`` with one shared random permutation.

        Both attributes are rebound to their reindexed versions, so
        image/label pairs stay aligned. The permutation comes from the
        stdlib ``random`` module's global state, not from torch's RNG.

        NOTE(review): ``torch.utils.data.DataLoader`` never calls this
        Keras-style hook. It only has an effect if the training loop calls
        it explicitly.
        """
        count = len(self)
        order = random.sample(population=list(range(count)), k=count)
        self.imgarr = self.imgarr[order]
        self.labels = self.labels[order]

    def preprocess(self, frame):
        """Divide ``frame`` by 255, then standardise with ``MEAN`` and ``STD``.

        ``MEAN`` and ``STD`` are globals defined outside this class.
        NOTE(review): this presumably expects raw 0-255 pixel values, and
        MEAN/STD shaped to broadcast against ``frame``. Verify both.
        """
        return (frame / 255.0 - MEAN) / STD
# Build the 10-class training and validation datasets and their loaders.
dg = DSDataGen('train', trimages, trlabels, num_classes=10)
# shuffle=True: DataLoader samples sequentially by default and never calls the
# Keras-style ``on_epoch_end`` hook. Without it, every epoch sees the training
# set in one fixed order, unless the training loop calls dg.on_epoch_end()
# itself. If it does, the two shuffles compose harmlessly.
dl = DataLoader(dg, batch_size=32, shuffle=True, drop_last=True)
vdg = DSDataGen('valid', valimages, vallabels, num_classes=10)
# NOTE(review): drop_last=True discards the final partial validation batch
# (up to 31 samples), so validation metrics skip those samples. Left unchanged
# because downstream code may assume full batches of 32. Confirm before
# switching it to False.
vdl = DataLoader(vdg, batch_size=32, drop_last=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment