Last active October 11, 2020 15:13
-
-
Save InnovArul/1f810b3739df4c5431c6be819f8795fe to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_dataloader_from_pth(path, batch_size=4, shuffle=True, num_workers=2):
    """Build a DataLoader over the ``'x'``/``'y'`` tensors stored in a file.

    The file is read with ``torch.load`` and its ``'x'`` and ``'y'`` entries
    are wrapped in a ``TensorDataset``. ``TensorDataset`` requires both
    tensors to share the same first dimension.

    Args:
        path: Path to a file previously written with ``torch.save``.
        batch_size: Number of samples per batch.
        shuffle: Reshuffle the samples every epoch. Defaults to ``True``
            (the original hard-coded behaviour). Pass ``False`` when a
            stable sample order is wanted, e.g. for validation/test loaders.
        num_workers: Number of worker subprocesses used for loading.
            ``0`` loads in the main process. Defaults to ``2``, the original
            hard-coded value.

    Returns:
        A ``torch.utils.data.DataLoader`` that yields batches of ``x`` and
        ``y`` samples.

    Raises:
        KeyError: If the loaded mapping has no ``'x'`` or ``'y'`` entry.
    """
    import torch  # the visible script never imports torch at module level

    # NOTE(security): torch.load unpickles the file; only load trusted data.
    contents = torch.load(path)
    dataset = torch.utils.data.TensorDataset(contents['x'], contents['y'])
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
#----------------------------------------------------------------------
# Dataset locations. NOTE(review): '<<<ROOT>>>' appears to be a placeholder;
# replace it with the directory that holds train.pth, val.pth and test.pth.
import os.path as osp
datasetpath = '<<<ROOT>>>'
train_pth = osp.join(datasetpath, 'train.pth')
val_pth = osp.join(datasetpath, 'val.pth')
test_pth = osp.join(datasetpath, 'test.pth')
#------------------------------------------------------------------------
# One DataLoader per split, all built with the helper's default settings.
# NOTE(review): the helper shuffles by default (shuffle=True), so the
# val/test batches also come out in random order. Confirm this is intended.
trainloader = get_dataloader_from_pth(train_pth)
valloader = get_dataloader_from_pth(val_pth)
testloader = get_dataloader_from_pth(test_pth)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment