Skip to content

Instantly share code, notes, and snippets.

@andreaschandra
Created August 26, 2021 09:55
Show Gist options
Save andreaschandra/a55f3ee1190344e2b56d73318190be8e to your computer and use it in GitHub Desktop.
import torch
from torch.utils.data import Dataset, DataLoader
# Define a map-style Dataset that pairs each feature row with its target.
class Iris(Dataset):
    """Map-style dataset pairing rows of ``x_array`` with entries of ``y_array``.

    Args:
        x_array: Feature container, indexed per sample as ``x_array[idx, :]``
            (e.g. a 2-D tensor of shape (num_samples, num_features)).
        y_array: Target container, indexed per sample as ``y_array[idx]``.

    NOTE: the dataset length is taken from ``x_array`` alone, so ``y_array``
    is expected to hold at least as many entries as ``x_array`` has rows.
    """

    def __init__(self, x_array, y_array):
        self.x, self.y = x_array, y_array

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair stored at position ``idx``."""
        return self.x[idx, :], self.y[idx]

    def __len__(self):
        """Return the number of samples, i.e. the length of the feature container."""
        return len(self.x)
# Create random inputs: 10 samples with 4 features each, plus one scalar target
# per sample. y must have as many entries as x has rows: Iris.__len__ reports
# len(x), so a shorter y (e.g. size 4) raises IndexError for any idx past the
# end of y once a full epoch is iterated through a DataLoader.
x = torch.rand(10, 4)
y = torch.rand(size=(10,))
print(x.shape, y.shape)
# Output:
# torch.Size([10, 4]) torch.Size([10])

# Wrap the tensors in the Iris dataset.
dataset = Iris(x, y)

# Fetch the first (features, target) pair through the iterator protocol.
# Iris defines no __iter__, so iter() falls back to calling dataset[0], dataset[1], ...
sample_x, sample_y = next(iter(dataset))
print("x:", sample_x)
print("y:", sample_y)
# Example output (values are random):
# x: tensor([0.2306, 0.7369, 0.7853, 0.3919])
# y: tensor(0.6016)

# Use a DataLoader to draw samples in batches of 2; the default collate
# function stacks the per-sample tensors along a new leading dimension.
data_generator = DataLoader(dataset, batch_size=2)
batch_x, batch_y = next(iter(data_generator))
print("x:", batch_x)
print("y:", batch_y)
# Example output (values are random):
# x: tensor([[0.2306, 0.7369, 0.7853, 0.3919],
#            [0.3299, 0.6081, 0.1574, 0.5445]])
# y: tensor([0.6016, 0.4003])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment