Created
December 27, 2019 12:25
-
-
Save georgepar/7f2851d78bbad453abf6efb759d08416 to your computer and use it in GitHub Desktop.
Dataloading helper for Pattern Recognition Lab 3 in NTUA
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
import gzip
import os

import numpy as np
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
# Maps raw FMA genre labels to the coarse genre set used for training.
# A value of None drops the genre from the dataset entirely.
class_mapping = {
    # Rock family collapses to 'Rock'
    'Rock': 'Rock',
    'Psych-Rock': 'Rock',
    'Indie-Rock': None,
    'Post-Rock': 'Rock',
    # Folk family
    'Psych-Folk': 'Folk',
    'Folk': 'Folk',
    # Metal / punk family collapses to 'Metal'
    'Metal': 'Metal',
    'Punk': 'Metal',
    'Post-Punk': None,
    # Kept as-is
    'Trip-Hop': 'Trip-Hop',
    'Pop': 'Pop',
    'Electronic': 'Electronic',
    'Hip-Hop': 'Hip-Hop',
    'Classical': 'Classical',
    'Blues': 'Blues',
    'Chiptune': 'Electronic',
    'Jazz': 'Jazz',
    # Discarded genres
    'Soundtrack': None,
    'International': None,
    'Old-Time': None,
}
def torch_train_val_split( | |
dataset, batch_train, batch_eval, | |
val_size=.2, shuffle=True, seed=42): | |
# Creating data indices for training and validation splits: | |
dataset_size = len(dataset) | |
indices = list(range(dataset_size)) | |
val_split = int(np.floor(val_size * dataset_size)) | |
if shuffle: | |
np.random.seed(seed) | |
np.random.shuffle(indices) | |
train_indices = indices[val_split:] | |
val_indices = indices[:val_split] | |
# Creating PT data samplers and loaders: | |
train_sampler = SubsetRandomSampler(train_indices) | |
val_sampler = SubsetRandomSampler(val_indices) | |
train_loader = DataLoader(dataset, | |
batch_size=batch_train, | |
sampler=train_sampler) | |
val_loader = DataLoader(dataset, | |
batch_size=batch_eval, | |
sampler=val_sampler) | |
return train_loader, val_loader | |
def read_spectrogram(spectrogram_file, chroma=True):
    """Load a gzip-compressed ``.npy`` spectrogram, returned time-major.

    The stored array is a fused mel spectrogram + chromagram (features
    along one axis); callers may decompose it downstream. ``chroma`` is
    accepted for interface compatibility and currently unused.
    """
    with gzip.GzipFile(spectrogram_file, 'r') as fp:
        fused = np.load(fp)
    # Transpose so that axis 0 is time and axis 1 is the feature dim.
    return fused.T
class LabelTransformer(LabelEncoder):
    """LabelEncoder that also accepts a single scalar label.

    sklearn's encoder requires array-like input; these overrides retry
    with the value wrapped in a list when a scalar is passed. The bare
    ``except:`` of the original was narrowed to the exceptions sklearn
    actually raises for scalar input, so real errors (and e.g.
    KeyboardInterrupt) are no longer swallowed.
    """

    def inverse(self, y):
        """Decode encoded label(s) ``y`` back to the original strings."""
        try:
            return super(LabelTransformer, self).inverse_transform(y)
        except (ValueError, TypeError):
            # y was a scalar, not a sequence — wrap and retry.
            return super(LabelTransformer, self).inverse_transform([y])

    def transform(self, y):
        """Encode label(s) ``y`` to integer class indices."""
        try:
            return super(LabelTransformer, self).transform(y)
        except (ValueError, TypeError):
            # y was a scalar, not a sequence — wrap and retry.
            return super(LabelTransformer, self).transform([y])
class PaddingTransform(object):
    """Crop or pad a 2-D sequence along axis 0 to exactly ``max_length`` rows.

    Fixes two issues in the original: ``padding_value`` was accepted but
    silently ignored (padding was always zeros — the default of 0 keeps
    prior behavior), and the ``copy.deepcopy`` before ``np.vstack`` was
    redundant since vstack already allocates a new array.
    """

    def __init__(self, max_length, padding_value=0):
        # Target number of time steps and the fill value for padding rows.
        self.max_length = max_length
        self.padding_value = padding_value

    def __call__(self, s):
        """Return ``s`` with shape ``(max_length, s.shape[1])``.

        Sequences longer than ``max_length`` are truncated; shorter ones
        are padded at the end with ``padding_value``.
        """
        if len(s) == self.max_length:
            return s
        if len(s) > self.max_length:
            return s[:self.max_length]
        pad = np.full((self.max_length - s.shape[0], s.shape[1]),
                      self.padding_value, dtype=np.float32)
        return np.vstack((s, pad))
class SpectrogramDataset(Dataset):
    """Dataset of gzip-compressed fused mel-spectrogram/chromagram clips.

    Expects ``path`` to contain a ``train/`` or ``test/`` subdirectory of
    spectrogram files plus a ``{train,test}_labels.txt`` index (one header
    line, then tab-separated ``<file>\\t<label>`` rows).

    Fixes in this revision: ``__len__`` previously had no body at all
    (a SyntaxError that prevented importing the module) and
    ``__getitem__`` raised NotImplementedError; both are implemented per
    the TODO contract.
    """

    def __init__(self, path, class_mapping=None, train=True, max_length=-1):
        t = 'train' if train else 'test'
        p = os.path.join(path, t)
        self.index = os.path.join(path, "{}_labels.txt".format(t))
        self.files, labels = self.get_files_labels(self.index, class_mapping)
        self.feats = [read_spectrogram(os.path.join(p, f)) for f in self.files]
        self.feat_dim = self.feats[0].shape[1]
        self.lengths = [len(i) for i in self.feats]
        # max_length <= 0 means "infer from the longest clip".
        self.max_length = max(self.lengths) if max_length <= 0 else max_length
        self.zero_pad_and_stack = PaddingTransform(self.max_length)
        self.label_transformer = LabelTransformer()
        if isinstance(labels, (list, tuple)):
            self.labels = np.array(
                self.label_transformer.fit_transform(labels)).astype('int64')

    def get_files_labels(self, txt, class_mapping):
        """Parse the index file; return (filenames, labels).

        Applies ``class_mapping`` when given, skipping rows whose mapped
        label is falsy (None entries drop the clip entirely).
        """
        with open(txt, 'r') as fd:
            lines = [l.rstrip().split('\t') for l in fd.readlines()[1:]]
        files, labels = [], []
        for l in lines:
            label = l[1]
            if class_mapping:
                label = class_mapping[l[1]]
            if not label:
                continue
            files.append(l[0])
            labels.append(label)
        return files, labels

    def __getitem__(self, clip):
        """Return ``(padded_spectrogram, label, original_length)`` for ``clip``.

        The spectrogram is cropped/padded to ``self.max_length`` time
        steps so samples stack cleanly into batches.
        """
        return (self.zero_pad_and_stack(self.feats[clip]),
                self.labels[clip],
                self.lengths[clip])

    def __len__(self):
        """Number of clips in the dataset."""
        return len(self.files)
if __name__ == '__main__':
    # Smoke-test: build the train dataset, split it, and wrap the test set.
    data_root = '../input/data/data/fma_genre_spectrograms/'
    specs = SpectrogramDataset(
        data_root, train=True, class_mapping=class_mapping, max_length=-1)
    train_loader, val_loader = torch_train_val_split(
        specs, 32, 32, val_size=.33)
    ttest_loader = DataLoader(
        SpectrogramDataset(
            data_root, train=False, class_mapping=class_mapping,
            max_length=-1))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment