-
-
Save andrewjong/6b02ff237533b3b2c554701fb53d5c4d to your computer and use it in GitHub Desktop.
import torch | |
from torchvision import datasets | |
class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths.

    Extends torchvision.datasets.ImageFolder so that every sample tuple is
    returned with the path of its image file appended as the last element.
    """

    # Override the __getitem__ method: this is the method that dataloader calls.
    def __getitem__(self, index):
        """Return the parent's sample tuple for ``index`` plus the file path."""
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path (first element of the entry in self.imgs)
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        return original_tuple + (path,)
# EXAMPLE USAGE:
# Imported here so the example is self-contained; only `datasets` is imported
# from torchvision at the top of the snippet.
from torchvision import transforms

# instantiate the dataset and dataloader
data_dir = "your/data_dir/here"
# ToTensor is needed because the DataLoader's default collate function can
# only batch tensors/arrays/numbers, not the PIL images ImageFolder loads.
dataset = ImageFolderWithPaths(data_dir, transform=transforms.ToTensor())  # our custom dataset
# DataLoader lives in torch.utils.data (torch.utils has no DataLoader attribute).
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
I got the same error.
The issue is related to the images folder location, but I'm unable to figure it out.
@flydragon2018 you need to add `ToTensor()` to your augmentation pipeline.
@a7906375 @tehreemnaqvi if your `data_dir` is a `pathlib.Path`, you need to apply `str()` before passing it to `ImageFolderWithPaths`.
Here is a concise version that I can confirm works
class ImageFolderWithPaths(ImageFolder):
    """ImageFolder variant whose samples also carry the image file path."""

    def __getitem__(self, index):
        """Return the parent's sample tuple with the file path appended."""
        sample = super(ImageFolderWithPaths, self).__getitem__(index)
        # The first element of the self.imgs entry is used as the path.
        image_path = self.imgs[index][0]
        return sample + (image_path,)
import torch
from torchvision import datasets, transforms

# Preprocessing pipeline: resize every image to 64x64, then convert to a tensor.
# NOTE: this rebinds the name `transforms` from the torchvision module to the
# composed pipeline; the name is kept because the dataset below is built with it.
transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])


class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder that appends each image's file path to the sample tuple."""

    # Must be spelled __getitem__ (with the double underscores): that is the
    # method the dataloader calls, so a plain `getitem` would never be used.
    def __getitem__(self, index):
        """Return the parent's sample tuple for ``index`` plus the file path."""
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        tuple_with_path = original_tuple + (path,)
        return tuple_with_path


data_dir = "./sig_datasets/"
dataset = ImageFolderWithPaths(data_dir, transform=transforms)
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for i, data in enumerate(dataloader):
    images, labels, paths = data
    print(images)
    # only the first batch is printed
    break
This code worked for me.
Works out of the box. Thanks!
You are my hero! thank you!
How would I modify this to isolate files with a wildcard? For example if I wanted to isolate all image files that start with vid_1234.
Wonderful! You saved my day!
Thanks; hard to imagine that ImageFolder doesn't have this function / flag
May I ask under what license this snippet is released?
import torch
import torchvision
from torchvision import datasets, transforms
# Preprocessing pipeline: resize each image to 64x64, then convert it to a tensor.
# Note that this assignment rebinds `transforms`, so from here on the name
# refers to the composed pipeline rather than the torchvision module.
transforms = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)
class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    """Custom dataset that includes image file paths.

    Extends torchvision.datasets.ImageFolder so that every sample tuple is
    returned with the path of its image file appended as the last element.
    """

    # Override the __getitem__ method: this is the method that dataloader calls.
    # (A method named plain `getitem` would not override anything.)
    def __getitem__(self, index):
        """Return the parent's sample tuple for ``index`` plus the file path."""
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = original_tuple + (path,)
        return tuple_with_path
# EXAMPLE USAGE:
# Build the custom dataset from the image folder, then wrap it in a dataloader.
data_dir = './dog_vs_cat/train/'
dataset = ImageFolderWithPaths(data_dir, transform=transforms)  # our custom dataset
dataloader = torch.utils.data.DataLoader(dataset)

# Walk over the data; each batch unpacks into inputs, labels and paths.
for batch in dataloader:
    inputs, labels, paths = batch
    # use the above variables freely
    print(inputs, labels, paths)
This code worked for me.
I am getting this error message: