import copy
import gzip
import os

import numpy as np
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler

# Combine similar classes and remove underrepresented classes
class_mapping = {
    'Rock': 'Rock',
    'Psych-Rock': 'Rock',
    'Indie-Rock': None,
    'Post-Rock': 'Rock',
    'Psych-Folk': 'Folk',
    'Folk': 'Folk',
    'Metal': 'Metal',
    'Punk': 'Metal',
    'Post-Punk': None,
    'Trip-Hop': 'Trip-Hop',
    'Pop': 'Pop',
    'Electronic': 'Electronic',
    'Hip-Hop': 'Hip-Hop',
    'Classical': 'Classical',
    'Blues': 'Blues',
    'Chiptune': 'Electronic',
    'Jazz': 'Jazz',
    'Soundtrack': None,
    'International': None,
    'Old-Time': None,
}
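
# For example, related genres are merged while sparse ones are dropped:
#   class_mapping['Psych-Rock'] -> 'Rock'  (merged into Rock)
#   class_mapping['Indie-Rock'] -> None    (dropped from the dataset)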
# TODO: Comment on how the train and validation splits are created.
# TODO: It's useful to set the seed when debugging, but when experimenting ALWAYS set seed=None. Why?
def torch_train_val_split(
        dataset, batch_train, batch_eval,
        val_size=.2, shuffle=True, seed=None):
    # Create data indices for the training and validation splits
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    val_split = int(np.floor(val_size * dataset_size))
    if shuffle:
        np.random.seed(seed)
        np.random.shuffle(indices)
    train_indices = indices[val_split:]
    val_indices = indices[:val_split]
    # Create PyTorch data samplers and loaders
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    train_loader = DataLoader(dataset,
                              batch_size=batch_train,
                              sampler=train_sampler)
    val_loader = DataLoader(dataset,
                            batch_size=batch_eval,
                            sampler=val_sampler)
    return train_loader, val_loader
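
# Design note: both loaders wrap the same dataset object; each
# SubsetRandomSampler draws (without replacement) only from the indices
# it was given, so the two loaders share storage but never overlap.
# Minimal usage sketch, assuming `dataset` is any map-style Dataset:
#   train_dl, val_dl = torch_train_val_split(dataset, 32, 32, val_size=.2)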
def read_spectrogram(spectrogram_file, chroma=True):
    # NOTE: the chroma flag is currently unused.
    with gzip.GzipFile(spectrogram_file, 'r') as f:
        spectrograms = np.load(f)
    # spectrograms contains a fused mel spectrogram and chromagram,
    # returned time-major: (frames, features). See the decomposition
    # sketch below.
    return spectrograms.T
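
# A minimal sketch of the decomposition hinted at above. The split point
# (n_mel mel bands followed by the chroma bins, with n_mel=128 as the
# usual default) is an assumption about how the features were fused,
# not something this file encodes:
def split_mel_chroma(spectrogram, n_mel=128):
    mel = spectrogram[:, :n_mel]     # (frames, n_mel) mel spectrogram
    chroma = spectrogram[:, n_mel:]  # (frames, n_chroma) chromagram
    return mel, chroma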
class LabelTransformer(LabelEncoder):
    # LabelEncoder that also accepts a single bare label, not just a list.
    def inverse(self, y):
        try:
            return super(LabelTransformer, self).inverse_transform(y)
        except Exception:
            return super(LabelTransformer, self).inverse_transform([y])

    def transform(self, y):
        try:
            return super(LabelTransformer, self).transform(y)
        except Exception:
            return super(LabelTransformer, self).transform([y])
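
# Sketch with made-up labels: after fit(['Rock', 'Jazz']), both
# lt.transform(['Rock']) and lt.transform('Rock') return array([1]),
# because the bare-label call falls back to the wrapped-list form.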
# TODO: Comment on why padding is needed
class PaddingTransform(object):
    def __init__(self, max_length, padding_value=0):
        self.max_length = max_length
        self.padding_value = padding_value

    def __call__(self, s):
        if len(s) == self.max_length:
            return s
        if len(s) > self.max_length:
            return s[:self.max_length]
        # len(s) < self.max_length: pad along the time axis
        pad = np.full((self.max_length - s.shape[0], s.shape[1]),
                      self.padding_value, dtype=np.float32)
        return np.vstack((s, pad))
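
# Why the transform exists: spectrograms have different numbers of
# frames, but a batch must be a rectangular tensor. Sketch with
# made-up sizes:
#   pad = PaddingTransform(max_length=5)
#   x = np.ones((3, 4), dtype=np.float32)  # 3 frames, 4 features
#   pad(x).shape                           # -> (5, 4): 2 zero rows added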
class SpectrogramDataset(Dataset):
    def __init__(self, path, class_mapping=None, train=True, max_length=-1):
        t = 'train' if train else 'test'
        p = os.path.join(path, t)
        self.index = os.path.join(path, "{}_labels.txt".format(t))
        self.files, labels = self.get_files_labels(self.index, class_mapping)
        self.feats = [read_spectrogram(os.path.join(p, f)) for f in self.files]
        self.feat_dim = self.feats[0].shape[1]
        self.lengths = [len(i) for i in self.feats]
        self.max_length = max(self.lengths) if max_length <= 0 else max_length
        self.zero_pad_and_stack = PaddingTransform(self.max_length)
        self.label_transformer = LabelTransformer()
        if isinstance(labels, (list, tuple)):
            self.labels = np.array(
                self.label_transformer.fit_transform(labels)).astype('int64')

    def get_files_labels(self, txt, class_mapping):
        with open(txt, 'r') as fd:
            lines = [l.rstrip().split('\t') for l in fd.readlines()[1:]]
        files, labels = [], []
        for l in lines:
            label = l[1]
            if class_mapping:
                label = class_mapping[l[1]]
            if not label:
                continue
            files.append(l[0])
            labels.append(label)
        return files, labels

    def __getitem__(self, item):
        # TODO: Inspect output and comment on how the output is formatted
        l = min(self.lengths[item], self.max_length)
        return self.zero_pad_and_stack(self.feats[item]), self.labels[item], l

    def __len__(self):
        return len(self.labels)

if __name__ == '__main__':
    specs = SpectrogramDataset(
        '../input/data/data/fma_genre_spectrograms/',
        train=True, class_mapping=class_mapping, max_length=-1)
    train_loader, val_loader = torch_train_val_split(
        specs, 32, 32, val_size=.33)
    test_loader = DataLoader(SpectrogramDataset(
        '../input/data/data/fma_genre_spectrograms/',
        train=False, class_mapping=class_mapping, max_length=-1))
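
    # Quick sanity check (sketch): each batch holds zero-padded features,
    # encoded integer labels, and the original (unpadded) lengths.
    feats, labels, lengths = next(iter(train_loader))
    print(feats.shape, labels.shape, lengths.shape)  # (32, T, F), (32,), (32,)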