Skip to content

Instantly share code, notes, and snippets.

@syaffers
Created May 15, 2019 07:19
Show Gist options
  • Save syaffers/5f63e07e05ca510fbb087618416437b8 to your computer and use it in GitHub Desktop.
Save syaffers/5f63e07e05ca510fbb087618416437b8 to your computer and use it in GitHub Desktop.
The second iteration of the TES names dataset with major additions and updates
import os
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
import torch
class TESNamesDataset(Dataset):
    """Map-style dataset of ``(race, gender, name)`` samples read from disk.

    The expected layout is ``data_root/<race>/<gender file>``: every
    sub-directory of ``data_root`` is a race, every regular file inside it
    is a gender (the file name is used verbatim as the label), and every
    non-blank line of that file is one name.

    Indexing the dataset returns the one-hot encoded tensors produced by
    :meth:`one_hot_sample` for the stored sample.

    Args:
        data_root: Directory containing one sub-directory per race.
        charset: Iterable of the characters that may occur in a name; it
            defines the classes of the character encoder.
    """

    def __init__(self, data_root, charset):
        super().__init__()
        self.data_root = data_root
        self.charset = charset
        self.samples = []
        self.race_codec = LabelEncoder()
        self.gender_codec = LabelEncoder()
        self.char_codec = LabelEncoder()
        self._init_dataset()

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the one-hot ``(race, gender, name)`` tensors of sample ``idx``."""
        race, gender, name = self.samples[idx]
        return self.one_hot_sample(race, gender, name)

    def _scan_data_root(self):
        """Walk ``data_root`` and collect the raw samples and label sets.

        Directory listings are sorted because ``os.listdir`` returns entries
        in arbitrary order; without sorting, the sample at a given index
        could differ between machines or runs.

        Returns:
            A ``(samples, races, genders)`` tuple: ``samples`` is a list of
            ``(race, gender, name)`` string tuples, ``races`` and
            ``genders`` are the sets of every label encountered (including
            labels that contributed no names).
        """
        samples = []
        races = set()
        genders = set()

        for race in sorted(os.listdir(self.data_root)):
            race_folder = os.path.join(self.data_root, race)
            # Stray files next to the race folders (README, .DS_Store, ...)
            # are not races and cannot be listed.
            if not os.path.isdir(race_folder):
                continue
            races.add(race)

            for gender in sorted(os.listdir(race_folder)):
                gender_filepath = os.path.join(race_folder, gender)
                # Nested directories cannot be opened as name files.
                if not os.path.isfile(gender_filepath):
                    continue
                genders.add(gender)

                with open(gender_filepath, 'r', encoding='utf-8') as gender_file:
                    for name in gender_file.read().splitlines():
                        # A blank line is not a name; keeping it would
                        # produce a sample with nothing to encode.
                        if name.strip():
                            samples.append((race, gender, name))

        return samples, races, genders

    def _init_dataset(self):
        """Load all samples from disk and fit the three label encoders."""
        samples, races, genders = self._scan_data_root()
        self.samples.extend(samples)

        self.race_codec.fit(list(races))
        self.gender_codec.fit(list(genders))
        self.char_codec.fit(list(self.charset))

    def to_one_hot(self, codec, values):
        """One-hot encode ``values`` with an already fitted ``codec``.

        Args:
            codec: Fitted encoder exposing ``transform`` and ``classes_``.
            values: Sequence of labels known to ``codec``.

        Returns:
            A tensor of shape ``(len(values), len(codec.classes_))`` whose
            rows are taken from an identity matrix.
        """
        # Cast explicitly to integer indices: for an empty ``values`` the
        # encoder may hand back a float array, which cannot index a tensor.
        value_idxs = torch.as_tensor(codec.transform(values), dtype=torch.long)
        return torch.eye(len(codec.classes_))[value_idxs]

    def one_hot_sample(self, race, gender, name):
        """Encode one sample as a ``(race, gender, name)`` tuple of tensors.

        ``race`` and ``gender`` each become a single one-hot row; ``name``
        becomes one one-hot row per character.
        """
        t_race = self.to_one_hot(self.race_codec, [race])
        t_gender = self.to_one_hot(self.gender_codec, [gender])
        t_name = self.to_one_hot(self.char_codec, list(name))
        return t_race, t_gender, t_name
if __name__ == '__main__':
    # Small smoke test: load the dataset, report its size and show one
    # encoded sample.
    import string
    import sys

    # The dataset location may be given as the first command-line argument;
    # otherwise fall back to the original hard-coded path.
    data_root = sys.argv[1] if len(sys.argv) > 1 else '/home/syafiq/Data/tes-names/'
    charset = string.ascii_letters + "-' "
    dataset = TESNamesDataset(data_root, charset)
    print(len(dataset))

    # Only show the demo sample when the dataset is large enough to have it.
    sample_idx = 420
    if sample_idx < len(dataset):
        print(dataset[sample_idx])
    else:
        print('no sample at index {}'.format(sample_idx))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment