Skip to content

Instantly share code, notes, and snippets.

@MLWhiz
Last active September 6, 2020 10:44
Show Gist options
  • Save MLWhiz/db1e6c02c0b228cae339226ba12ccfdb to your computer and use it in GitHub Desktop.
Save MLWhiz/db1e6c02c0b228cae339226ba12ccfdb to your computer and use it in GitHub Desktop.
import os
from glob import escape, glob

from PIL import Image
from torch.utils.data import Dataset
class customImageFolderDataset(Dataset):
    """Custom image dataset for a ``root/<label>/<image file>`` layout.

    Each immediate sub-directory of ``root`` is one class. Every regular file
    directly inside such a sub-directory is one sample, and its label is the
    sub-directory's name.
    """

    def __init__(self, root, transform=None):
        """
        Args:
            root (string or path-like): Directory whose sub-directories hold
                the images, one sub-directory per class (``root/<label>/<file>``).
            transform: Optional callable (e.g. a torchvision transform) applied
                to each PIL image in ``__getitem__``.
        """
        # Escape ``root`` so characters such as "[" in the directory name are
        # matched literally instead of being interpreted as glob syntax.
        pattern = os.path.join(escape(os.fspath(root)), "*", "*")
        # glob's result order depends on the filesystem, so sort it for a
        # reproducible sample order. Keep only regular files so that nested
        # sub-directories are not mistaken for images.
        self.image_paths = sorted(
            path for path in glob(pattern) if os.path.isfile(path)
        )
        # The label is the name of the directory containing the file. Using
        # os.path keeps this correct where the separator is not "/".
        self.labels = [
            os.path.basename(os.path.dirname(path)) for path in self.image_paths
        ]
        # Sorted class names give the same label -> index mapping on every run.
        # Iterating a set of strings is not reproducible because str hashing
        # is randomized per interpreter process.
        self.classes = sorted(set(self.labels))
        self.label_to_idx = {label: idx for idx, label in enumerate(self.classes)}
        self.transform = transform

    def __len__(self):
        """Return the number of image files found under ``root``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``.

        ``image`` is the PIL image, passed through ``transform`` when one was
        given. ``class_index`` is ``label_to_idx`` of the sample's label.
        """
        img_name = self.image_paths[idx]
        label = self.labels[idx]
        # Image.open reads lazily and keeps the file open until the pixels are
        # loaded. Loading inside ``with`` closes the handle immediately instead
        # of leaving it to garbage collection.
        with Image.open(img_name) as image:
            image.load()
        if self.transform is not None:
            image = self.transform(image)
        return image, self.label_to_idx[label]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment