Skip to content

Instantly share code, notes, and snippets.

@tkshnkmr
Last active November 5, 2019 13:26
Show Gist options
  • Save tkshnkmr/7399dbef38048e53923fc68fc8554827 to your computer and use it in GitHub Desktop.
Save tkshnkmr/7399dbef38048e53923fc68fc8554827 to your computer and use it in GitHub Desktop.
import os
import torch
from torch.utils import data
from PIL import Image
from torchvision import transforms
class simpleDataset(data.Dataset):
    """Map-style dataset pairing image files under a root directory with labels.

    Sample ``index`` is the image stored at
    ``os.path.join(root, filenames[index])`` together with ``labels[index]``.

    Args:
        root: Directory that the entries of ``filenames`` are relative to.
        filenames: Sequence of image file names/paths, relative to ``root``.
        labels: Sequence of labels indexed in parallel with ``filenames``.
            Each label must be convertible by ``torch.as_tensor`` to int64.
        transform: Optional callable applied to the opened PIL image; its
            return value becomes the sample's image. Defaults to
            ``transforms.ToTensor()``. The image file is closed as soon as the
            transform returns, so a custom transform must materialise the
            pixel data (e.g. convert to a tensor or call ``.load()``/``.copy()``)
            rather than hand back the still-lazy opened image.
    """

    def __init__(self, root, filenames, labels, transform=None):
        # the data directory
        self.root = root
        # the list of filenames, relative to root
        self.filenames = filenames
        # the list of labels, parallel to filenames
        self.labels = labels
        # Build the default transform once instead of on every __getitem__ call.
        self.transform = transforms.ToTensor() if transform is None else transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        ``image`` is whatever ``self.transform`` returns (a tensor with the
        default transform); ``label`` is an ``int64`` tensor.
        """
        path = os.path.join(self.root, self.filenames[index])
        # Image.open is lazy and keeps the underlying file open; the context
        # manager closes it once the transform has read the pixel data, so
        # repeated indexing (e.g. by DataLoader workers) does not leak file
        # descriptors.
        with Image.open(path) as img:
            image = self.transform(img)
        # output of a Dataset used with the default collate must be tensors
        label = torch.as_tensor(self.labels[index], dtype=torch.int64)
        return image, label

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.filenames)``."""
        return len(self.filenames)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment