Created
February 28, 2017 21:13
-
-
Save ncullen93/14b458cb4bd237bab2a41a185f710808 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Custom datasets from both in-memory and out-of-memory data | |
""" | |
import torch.utils.data as data | |
from PIL import Image | |
import os | |
import os.path | |
# File-name suffixes that mark a file as an image. Kept as a module-level
# list; is_image_file reads it on every call.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True if *filename* ends with a suffix listed in IMG_EXTENSIONS.

    The comparison ignores case, so mixed-case suffixes such as ``.Jpg`` or
    ``.Png`` are accepted in addition to the all-lower and all-upper
    spellings that are listed explicitly.

    Arguments:
        filename: file name (or path) as a string.

    Returns:
        bool: whether the name carries a recognised image suffix.
    """
    lowered = filename.lower()
    # Lower-case both sides at call time so that entries added to
    # IMG_EXTENSIONS later are honoured regardless of their spelling.
    return any(lowered.endswith(extension.lower())
               for extension in IMG_EXTENSIONS)
def find_classes(dir):
    """List the class folders directly under *dir* and number them.

    Only entries of *dir* that are directories count as classes; plain
    files are ignored.

    Arguments:
        dir: path of the directory to scan.

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted
        list of subdirectory names and ``class_to_idx`` maps each name to
        its position in that list.
    """
    subdirs = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    lookup = {name: position for position, name in enumerate(subdirs)}
    return subdirs, lookup
def make_dataset(dir, class_to_idx):
    """Collect ``(image_path, class_index)`` pairs from the subfolders of *dir*.

    Every subdirectory of *dir* is walked recursively. Each file accepted
    by ``is_image_file`` is recorded together with the index that
    *class_to_idx* holds for the top-level subdirectory name.

    Arguments:
        dir: root directory that contains one subdirectory per class.
        class_to_idx: mapping from subdirectory name to class index.

    Returns:
        list: ``(path, class_index)`` tuples. Each path is built by joining
        onto *dir*, so it already contains *dir* as a prefix. The list is
        ordered by class folder, then by walked directory, then by file
        name.

    Raises:
        Whatever ``class_to_idx[name]`` raises (``KeyError`` for a dict) when
        a subdirectory that contains images is missing from the mapping.
    """
    images = []
    # os.listdir gives entries in arbitrary order; sorting keeps the sample
    # order reproducible across runs and file systems.
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            # File names within one directory are likewise unordered.
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append((path, class_to_idx[target]))
    return images
def default_loader(path):
    """Load the image stored at *path* and return it converted to RGB.

    Arguments:
        path: filesystem path of the image file.

    Returns:
        The PIL image produced by ``convert('RGB')``.
    """
    # Open the file ourselves so the handle is closed deterministically when
    # the block exits, instead of depending on PIL / garbage collection.
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() must run while the file is still open: it is the call
        # that actually reads the pixel data.
        return img.convert('RGB')
class FolderDataset(data.Dataset):
    """Dataset of image files stored on disk, one subfolder per class.

    Arguments:
        root: directory whose immediate subdirectories are the classes.
        transform: optional callable applied to the loaded image.
        target_transform: optional callable applied to the class index.
        co_transform: optional callable invoked as
            ``co_transform(img, target)`` after the two individual
            transforms; it must return an ``(img, target)`` pair.
        loader: callable that receives a file path and returns the image.

    Raises:
        RuntimeError: if no image file is found below *root*.
    """

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 co_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if not imgs:
            # format() keeps the message identical for string roots while
            # also accepting path-like objects, which "+" would reject.
            raise RuntimeError(
                "Found 0 images in subfolders of: {}\n"
                "Supported image extensions are: {}".format(
                    root, ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs  # list of (path, class_index) tuples
        self.classes = classes  # sorted class folder names
        self.class_to_idx = class_to_idx  # class name -> class index
        self.transform = transform
        self.target_transform = target_transform
        self.co_transform = co_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``(img, target)`` pair stored at position *index*."""
        path, target = self.imgs[index]
        # The paths produced by make_dataset already start with the root
        # directory. Joining root onto them again duplicated a relative
        # root ("data/data/cat/x.png") and made loading fail.
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # The joint transform runs last, on the individually transformed pair.
        if self.co_transform is not None:
            img, target = self.co_transform(img, target)
        return img, target

    def __len__(self):
        """Return the number of images found below the root directory."""
        return len(self.imgs)
class TensorDataset(data.Dataset):
    """Dataset that pairs up the entries of two in-memory tensors.

    Item ``i`` is ``(input_tensor[i], target_tensor[i])`` after the optional
    transforms have been applied.

    Arguments:
        input_tensor: inputs; must support ``size(0)`` and indexing.
        target_tensor: targets; ``size(0)`` must equal that of the inputs.
        transform: optional callable applied to the input sample.
        target_transform: optional callable applied to the target sample.
        co_transform: optional callable invoked as
            ``co_transform(img, target)`` after the two individual
            transforms; it must return an ``(img, target)`` pair.
    """

    def __init__(self,
                 input_tensor,
                 target_tensor,
                 transform=None,
                 target_transform=None,
                 co_transform=None):
        n_inputs = input_tensor.size(0)
        n_targets = target_tensor.size(0)
        # Both tensors must hold the same number of samples.
        assert n_inputs == n_targets

        self.input_tensor = input_tensor
        self.target_tensor = target_tensor
        self.transform = transform
        self.target_transform = target_transform
        self.co_transform = co_transform

    def __getitem__(self, index):
        """Return the transformed ``(input, target)`` pair at *index*."""
        sample = self.input_tensor[index]
        label = self.target_tensor[index]

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)

        # Without a joint transform the pair is already final.
        if self.co_transform is None:
            return sample, label

        sample, label = self.co_transform(sample, label)
        return sample, label

    def __len__(self):
        """Return the number of samples (first dimension of the inputs)."""
        return self.input_tensor.size(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment