Skip to content

Instantly share code, notes, and snippets.

@christopher-beckham
Created March 11, 2019 01:37
Show Gist options
  • Save christopher-beckham/b5ca86c8a43e873137442e7211eb655f to your computer and use it in GitHub Desktop.
Save christopher-beckham/b5ca86c8a43e873137442e7211eb655f to your computer and use it in GitHub Desktop.
# Original source: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/stargan/datasets.py
# Modified by Christopher Beckham
import glob
import random
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
# Space-separated names of the 40 attribute columns read from the CelebA
# annotation file; CelebADataset uses ``all_attrs.split()`` as the list of
# attributes it extracts for every image.
all_attrs = (
    "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald "
    "Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair "
    "Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair "
    "Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache "
    "Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose "
    "Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair "
    "Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick "
    "Wearing_Necklace Wearing_Necktie Young"
)
class CelebADataset(Dataset):
    """CelebA images paired with binary attribute labels.

    ``root`` must contain the ``*.jpg`` images together with the
    ``list_attr_celeba.txt`` annotation file. Image paths are sorted; the
    last 2000 of them form the ``'valid'`` split and everything before
    them forms the ``'train'`` split.
    """

    def __init__(self, root, transforms_=None, mode='train', missing_ind=None):
        """Index the image files and load their attribute annotations.

        Args:
            root: Directory holding the images and ``list_attr_celeba.txt``.
            transforms_: List of transforms handed to ``transforms.Compose``.
                ``None`` is treated as an empty pipeline.
            mode: ``'train'`` or ``'valid'``; selects which slice of the
                sorted file list this dataset serves.
            missing_ind: Stored unchanged on ``self.missing_ind``; nothing
                in this class reads it.
                NOTE(review): the original code assigned an undefined
                ``missing_ind`` name, so the intended meaning of this value
                is unknown -- confirm against callers.
        """
        assert mode in ['train', 'valid']
        # Compose(None) would only blow up on the first __getitem__ call,
        # so map the default to an explicit empty pipeline.
        self.transform = transforms.Compose(
            [] if transforms_ is None else transforms_
        )
        all_files = sorted(glob.glob('%s/*.jpg' % root))
        if mode == 'train':
            self.files = all_files[:-2000]
        else:
            self.files = all_files[-2000:]
        self.label_path = "%s/list_attr_celeba.txt" % root
        self.missing_ind = missing_ind
        # Must be set before get_annotations(), which reads it.
        self.selected_attrs = all_attrs.split()
        self.annotations = self.get_annotations()
        self.keys = list(self.annotations.keys())

    def get_annotations(self):
        """Extracts annotations for CelebA.

        Reads ``self.label_path``: the second line is taken as the
        whitespace-separated attribute names (stored on
        ``self.label_names``) and every following non-blank line as
        ``<filename> <value> <value> ...``.

        Returns:
            dict mapping each filename to a list of ints, one per entry of
            ``self.selected_attrs`` and in that order: 1 where the file's
            value for the attribute is the string ``'1'``, otherwise 0.

        Raises:
            ValueError: if a selected attribute is not among the names on
                the header line.
        """
        with open(self.label_path, 'r') as label_file:
            lines = [line.rstrip() for line in label_file]
        self.label_names = lines[1].split()
        # Resolve each selected attribute to its column once, not per row.
        columns = [self.label_names.index(attr) for attr in self.selected_attrs]
        annotations = {}
        for line in lines[2:]:
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            filename, *values = line.split()
            annotations[filename] = [int(values[col] == '1') for col in columns]
        return annotations

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The index wraps around modulo the number of files. The image is
        opened with PIL and passed through ``self.transform``; the label is
        a float tensor built from the file's annotation list.
        """
        filepath = self.files[index % len(self.files)]
        # Annotations are keyed by bare filename, not by full path.
        filename = os.path.basename(filepath)
        img = self.transform(Image.open(filepath))
        label = torch.FloatTensor(np.array(self.annotations[filename]))
        return img, label

    def __len__(self):
        """Return the number of image files in the selected split."""
        return len(self.files)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment