Skip to content

Instantly share code, notes, and snippets.

@conormm
Created May 3, 2018 21:57
Show Gist options
  • Save conormm/58d64b3453f7b21adc443798fb58afdd to your computer and use it in GitHub Desktop.
Save conormm/58d64b3453f7b21adc443798fb58afdd to your computer and use it in GitHub Desktop.
class PrepareData(Dataset):
    """Map-style dataset pairing feature rows ``X`` with targets ``y``.

    Both inputs are stored as torch tensors. Inputs that are already tensors
    are kept as-is. Anything else (e.g. a NumPy array or a nested list) is
    converted with ``torch.as_tensor``, which shares memory with NumPy arrays
    exactly like ``torch.from_numpy`` does.

    Args:
        X: Features, indexed along the first dimension.
        y: Targets, indexed along the first dimension.

    Raises:
        ValueError: If ``X`` and ``y`` do not have the same length.
    """

    def __init__(self, X, y):
        # Assign in both branches: a tensor input must be stored too,
        # otherwise __len__/__getitem__ would hit a missing attribute.
        self.X = X if torch.is_tensor(X) else torch.as_tensor(X)
        self.y = y if torch.is_tensor(y) else torch.as_tensor(y)
        # Catch mismatched inputs up front rather than as an IndexError
        # partway through an epoch.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same length, "
                f"got {len(self.X)} and {len(self.y)}"
            )

    def __len__(self):
        """Return the number of samples (size of the first dimension of X)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair at position ``idx``."""
        return self.X[idx], self.y[idx]
# Wrap X and y (assumed to be defined earlier in the script) in a dataset
# and batch it: 50 samples per batch, shuffled because shuffle=True.
# Note that `ds` ends up bound to the DataLoader, not to the Dataset.
ds = DataLoader(PrepareData(X=X, y=y), batch_size=50, shuffle=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment