Created
March 11, 2019 01:37
-
-
Save christopher-beckham/b5ca86c8a43e873137442e7211eb655f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Original source: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/stargan/datasets.py
# Modified by Christopher Beckham
import glob | |
import random | |
import os | |
import numpy as np | |
import torch | |
from torch.utils.data import Dataset | |
from PIL import Image | |
import torchvision.transforms as transforms | |
# Space-separated attribute names; CelebADataset splits this string to decide
# which annotation columns become label entries, in this order.
all_attrs = (
    "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs "
    "Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows "
    "Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup "
    "High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard "
    "Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns "
    "Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat "
    "Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young"
)
class CelebADataset(Dataset):
    """Images found under ``root`` paired with binary attribute labels.

    The ``*.jpg`` files directly inside ``root`` are sorted by path; the last
    2000 form the ``'valid'`` split and everything before them the ``'train'``
    split. Labels are read from ``<root>/list_attr_celeba.txt``.
    """

    def __init__(self, root, transforms_=None, mode='train', missing_ind=None):
        """Index the image files and load their attribute annotations.

        Args:
            root: Directory holding the ``*.jpg`` images and
                ``list_attr_celeba.txt``.
            transforms_: List of transforms handed to ``transforms.Compose``.
                ``None`` means no transforms are applied.
            mode: ``'train'`` or ``'valid'``; selects the split.
            missing_ind: Optional value stored unchanged on
                ``self.missing_ind``; this class does not read it.
        """
        assert mode in ['train', 'valid'], "mode must be 'train' or 'valid'"
        # Compose(None) would only fail later, on the first __getitem__ call.
        self.transform = transforms.Compose(
            transforms_ if transforms_ is not None else []
        )
        files = sorted(glob.glob('%s/*.jpg' % root))
        self.files = files[:-2000] if mode == 'train' else files[-2000:]
        self.label_path = "%s/list_attr_celeba.txt" % root
        self.missing_ind = missing_ind
        # Must be assigned before get_annotations(), which reads it.
        self.selected_attrs = all_attrs.split()
        self.annotations = self.get_annotations()
        self.keys = list(self.annotations.keys())

    def get_annotations(self):
        """Extracts annotations for CelebA.

        Reads ``self.label_path``: the first line is ignored, the second line
        holds the attribute names (stored on ``self.label_names``), and every
        following non-blank line is a filename followed by one value per
        attribute.

        Returns:
            Dict mapping filename to a list of ints, one per entry of
            ``self.selected_attrs``: 1 where the value is ``'1'``, else 0.

        Raises:
            ValueError: If a selected attribute is absent from the header.
        """
        with open(self.label_path, 'r') as label_file:
            lines = [line.rstrip() for line in label_file]
        self.label_names = lines[1].split()
        # Resolve the column of each selected attribute once, not per row.
        columns = [self.label_names.index(attr) for attr in self.selected_attrs]
        annotations = {}
        for line in lines[2:]:
            if not line.strip():
                # Tolerate blank lines (e.g. trailing newlines at EOF).
                continue
            filename, *values = line.split()
            annotations[filename] = [int(values[col] == '1') for col in columns]
        return annotations

    def __getitem__(self, index):
        """Return ``(img, label)`` for ``index``, wrapping past the end.

        ``img`` is the transformed image and ``label`` a float32 tensor of the
        0/1 attribute values for that file.
        """
        filepath = self.files[index % len(self.files)]
        # Annotations are keyed by bare filename, not by full path.
        filename = os.path.basename(filepath)
        img = self.transform(Image.open(filepath))
        label = torch.tensor(self.annotations[filename], dtype=torch.float32)
        return img, label

    def __len__(self):
        """Return the number of image files in the selected split."""
        return len(self.files)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment