@dvgodoy
Created April 27, 2019 16:28
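import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Assumes x, y (NumPy arrays) and disjoint index arrays train_idx, val_idx exist
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
# Wrap the full data once; each sampler picks the rows its loader sees
x_tensor = torch.from_numpy(x).float()
y_tensor = torch.from_numpy(y).float()
dataset = TensorDataset(x_tensor, y_tensor)
# A sampler takes the place of shuffle=True (the two are mutually exclusive)
train_loader = DataLoader(dataset=dataset, batch_size=16, sampler=train_sampler)
val_loader = DataLoader(dataset=dataset, batch_size=20, sampler=val_sampler)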
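The snippet relies on x, y, train_idx and val_idx being defined elsewhere. As a minimal, self-contained sketch, the example below invents them: 100 synthetic points on a noisy line and a shuffled 80/20 index split. Both the data and the split sizes are hypothetical, not taken from the gist. It then checks how many rows each loader yields per batch.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Hypothetical synthetic data (assumption): y = 1 + 2x plus small noise
np.random.seed(42)
x = np.random.rand(100, 1)
y = 1 + 2 * x + 0.1 * np.random.randn(100, 1)

# Hypothetical 80/20 split of shuffled row indices (assumption)
idx = np.arange(100)
np.random.shuffle(idx)
train_idx, val_idx = idx[:80], idx[80:]

x_tensor = torch.from_numpy(x).float()
y_tensor = torch.from_numpy(y).float()
dataset = TensorDataset(x_tensor, y_tensor)

# Same loader setup as the gist: one dataset, two index-restricted samplers
train_loader = DataLoader(dataset=dataset, batch_size=16, sampler=SubsetRandomSampler(train_idx))
val_loader = DataLoader(dataset=dataset, batch_size=20, sampler=SubsetRandomSampler(val_idx))

# Record the batch sizes each loader produces in one pass
train_batches = [xb.shape[0] for xb, _ in train_loader]
val_batches = [xb.shape[0] for xb, _ in val_loader]
print(train_batches)  # 80 rows / 16 -> [16, 16, 16, 16, 16]
print(val_batches)    # 20 rows / 20 -> [20]
```

Because SubsetRandomSampler draws a fresh permutation of its indices each time the loader is iterated, batch order changes between epochs while the set of rows each loader covers stays fixed.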