Created September 19, 2018 16:34
-
-
Save christopher-beckham/7fa3b258bc9ba361b921af407a051303 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob | |
import random | |
import os | |
import numpy as np | |
import torch | |
from torch.utils.data import Dataset | |
from PIL import Image | |
import torchvision.transforms as transforms | |
class CelebADataset(Dataset):
    """CelebA images with optional binary attribute labels.

    Images are the ``*.jpg`` files directly inside ``root``, sorted by path.
    The last ``n_test`` of them form the held-out split and everything before
    them forms the training split. Attribute labels are read from
    ``list_attr_celeba.txt`` inside ``root``.

    Args:
        root: folder containing the extracted CelebA ``.jpg`` images and
            ``list_attr_celeba.txt``.
        transforms_: sequence of torchvision transforms applied (through
            ``transforms.Compose``) to every image. ``None`` means no
            transform, so the RGB PIL image is returned as-is.
        mode: ``'train'`` selects the training split; any other value
            selects the held-out split.
        attributes: names of the attribute columns to return as labels.
            When empty (the default), ``__getitem__`` returns only the image.
        missing_ind: if True, append one extra label that is 1 exactly when
            none of the selected attributes is set.
        n_test: number of files, taken from the end of the sorted list,
            reserved for the held-out split.
    """

    def __init__(self, root, transforms_=None, mode='train',
                 attributes=None, missing_ind=False, n_test=2000):
        # Compose(None) would fail on the first call, so treat None as
        # "no transforms".
        self.transform = transforms.Compose(
            transforms_ if transforms_ is not None else [])
        # Copy so neither a shared default nor later mutation of the caller's
        # list can change which labels this dataset produces.
        self.selected_attrs = list(attributes) if attributes is not None else []
        files = sorted(glob.glob(os.path.join(root, '*.jpg')))
        # Clamp at 0: n_test=0 keeps every file in the train split (a bare
        # files[:-0] would be empty), and n_test > len(files) puts all files
        # in the held-out split, matching the original slicing behaviour.
        split = max(len(files) - n_test, 0)
        self.files = files[:split] if mode == 'train' else files[split:]
        self.label_path = os.path.join(root, 'list_attr_celeba.txt')
        self.missing_ind = missing_ind
        self.annotations = self.get_annotations()
        self.keys = list(self.annotations.keys())

    def get_annotations(self):
        """Parse the attribute file into ``{filename: [0/1, ...]}``.

        Expected layout: line 0 is ignored, line 1 holds whitespace-separated
        attribute names, and each following line is
        ``<filename> <v1> <v2> ...``, where the value ``'1'`` marks an
        attribute as present (any other value counts as absent). Blank lines
        are skipped. As a side effect, sets ``self.label_names`` to the
        header's attribute names.

        Returns:
            dict mapping each filename to a list of ints, one per entry of
            ``self.selected_attrs`` in that order, plus a trailing
            "none of the selected attributes" flag when ``self.missing_ind``
            is true.

        Raises:
            ValueError: if a selected attribute is not in the header.
        """
        with open(self.label_path, 'r') as fh:
            lines = [line.rstrip() for line in fh]
        self.label_names = lines[1].split()

        # Resolve each selected attribute's column once, not once per row.
        col_idx = []
        for attr in self.selected_attrs:
            if attr not in self.label_names:
                raise ValueError('Unknown CelebA attribute %r; available: %s'
                                 % (attr, ', '.join(self.label_names)))
            col_idx.append(self.label_names.index(attr))

        annotations = {}
        for line in lines[2:]:
            if not line:
                continue  # tolerate blank / trailing lines
            filename, *values = line.split()
            labels = [int(values[i] == '1') for i in col_idx]
            if self.missing_ind:
                # Extra 'everything else' class: set only when no selected
                # attribute is present for this image.
                labels.append(0 if 1 in labels else 1)
            annotations[filename] = labels
        return annotations

    def __getitem__(self, index):
        """Return ``img``, or ``(img, label)`` when attributes were selected.

        ``index`` wraps around modulo ``len(self.files)``. ``label`` is a
        float32 tensor of the 0/1 values stored in ``self.annotations``.
        """
        filepath = self.files[index % len(self.files)]
        # Read and close the file right away. convert('RGB') returns a copy
        # that stays valid after the file is closed, and it normalises any
        # non-RGB image to 3 channels.
        with Image.open(filepath) as raw:
            rgb = raw.convert('RGB')
        img = self.transform(rgb)
        if not self.selected_attrs:
            return img
        # basename (not split('/')) so Windows-style paths from glob also work.
        label = self.annotations[os.path.basename(filepath)]
        return img, torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        """Number of image files in the selected split."""
        return len(self.files)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
root
-> the folder containing the extracted contents of img_align_celeba.zip
list_attr_celeba.txt
-> https://github.com/andrewliao11/CoGAN-tensorflow/blob/master/list_attr_celeba.txt (this file must be placed in the same folder as the one given by root)