Skip to content

Instantly share code, notes, and snippets.

@georgepar
Created December 27, 2019 12:25
Show Gist options
  • Save georgepar/7f2851d78bbad453abf6efb759d08416 to your computer and use it in GitHub Desktop.
Save georgepar/7f2851d78bbad453abf6efb759d08416 to your computer and use it in GitHub Desktop.
Data-loading helper for Pattern Recognition Lab 3 at NTUA
import copy
import gzip
import os

import numpy as np
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
from torch.utils.data import SubsetRandomSampler, DataLoader
# Maps each raw genre label read from the dataset index files to the coarser
# class used for training. Genres mapped to None are dropped entirely:
# SpectrogramDataset.get_files_labels skips any sample whose mapped label is
# falsy. The non-None values form 10 distinct target classes.
class_mapping = {
    'Rock': 'Rock',
    'Psych-Rock': 'Rock',
    'Indie-Rock': None,
    'Post-Rock': 'Rock',
    'Psych-Folk': 'Folk',
    'Folk': 'Folk',
    'Metal': 'Metal',
    'Punk': 'Metal',
    'Post-Punk': None,
    'Trip-Hop': 'Trip-Hop',
    'Pop': 'Pop',
    'Electronic': 'Electronic',
    'Hip-Hop': 'Hip-Hop',
    'Classical': 'Classical',
    'Blues': 'Blues',
    'Chiptune': 'Electronic',
    'Jazz': 'Jazz',
    'Soundtrack': None,
    'International': None,
    'Old-Time': None
}
def torch_train_val_split(
        dataset, batch_train, batch_eval,
        val_size=.2, shuffle=True, seed=42):
    """Split ``dataset`` into training and validation ``DataLoader``s.

    After an optional seeded shuffle of the sample indices, the first
    ``floor(val_size * len(dataset))`` indices form the validation subset and
    the rest form the training subset. Each loader visits its subset in random
    order through a ``SubsetRandomSampler``.

    Args:
        dataset: Map-style dataset supporting ``len()`` and integer indexing.
        batch_train: Batch size of the training loader.
        batch_eval: Batch size of the validation loader.
        val_size: Fraction of the samples assigned to validation.
        shuffle: Whether to shuffle the indices before splitting.
        seed: Seed of the index shuffle (ignored when ``shuffle`` is False).

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    val_split = int(np.floor(val_size * dataset_size))
    if shuffle:
        # A private, identically seeded RandomState produces exactly the same
        # permutation as np.random.seed(seed) + np.random.shuffle(), but does
        # not reset the process-wide NumPy RNG as a side effect.
        np.random.RandomState(seed).shuffle(indices)
    train_indices = indices[val_split:]
    val_indices = indices[:val_split]

    # Both loaders share the same underlying dataset; the samplers restrict
    # each one to its own disjoint index subset.
    train_loader = DataLoader(dataset,
                              batch_size=batch_train,
                              sampler=SubsetRandomSampler(train_indices))
    val_loader = DataLoader(dataset,
                            batch_size=batch_eval,
                            sampler=SubsetRandomSampler(val_indices))
    return train_loader, val_loader
def read_spectrogram(spectrogram_file, chroma=True):
    """Load a gzip-compressed NumPy array from disk and return its transpose.

    Per the original author's note, the stored array is a fused mel
    spectrogram and chromagram. The loaded array is returned transposed
    (equivalent to ``ndarray.T``). Callers in this file treat axis 0 of the
    result as the sequence axis and axis 1 as the feature axis.

    Args:
        spectrogram_file: Path of the gzip-compressed ``.npy`` file.
        chroma: Currently unused; the fused array is always returned whole.
            NOTE(review): presumably intended to control whether the
            chromagram part is kept once decomposition is implemented.

    Returns:
        The transposed array.
    """
    with gzip.GzipFile(spectrogram_file, 'r') as compressed:
        fused = np.load(compressed)
    # TODO: decompose ``fused`` into its mel-spectrogram and chromagram parts;
    # the split point is not recorded in this file.
    return np.transpose(fused)
class LabelTransformer(LabelEncoder):
    """``LabelEncoder`` whose ``transform``/``inverse`` also accept one label.

    A scalar input (e.g. ``'Rock'`` or ``3``) is wrapped in a one-element list
    before being forwarded to ``LabelEncoder``. Such calls therefore return a
    one-element array rather than a scalar.
    """

    def inverse(self, y):
        """Map encoded integer label(s) ``y`` back to the original class(es)."""
        return super().inverse_transform(self._as_1d(y))

    def transform(self, y):
        """Encode class label(s) ``y`` as integers."""
        return super().transform(self._as_1d(y))

    @staticmethod
    def _as_1d(y):
        # Wrap scalars explicitly instead of retrying after a bare ``except:``.
        # The bare except also caught KeyboardInterrupt/SystemExit, and it hid
        # the real error (e.g. "previously unseen labels") behind a second,
        # confusing failure on the wrapped input.
        return [y] if np.ndim(y) == 0 else y
class PaddingTransform(object):
    """Pad or truncate an array along its first axis to a fixed length.

    Args:
        max_length: Target length of axis 0.
        padding_value: Value written into padded positions (default 0).
    """

    def __init__(self, max_length, padding_value=0):
        self.max_length = max_length
        self.padding_value = padding_value

    def __call__(self, s):
        """Return ``s`` with axis 0 padded or truncated to ``max_length``.

        - If ``s`` already has the target length, it is returned unchanged
          (the same object).
        - If ``s`` is longer, it is sliced (``s[:max_length]``).
        - If ``s`` is shorter, a new array is returned: ``s`` followed by rows
          filled with ``padding_value``.
        """
        n = len(s)
        if n == self.max_length:
            return s
        if n > self.max_length:
            return s[:self.max_length]
        # The pad matches every trailing dimension of ``s``, so 1-D and N-D
        # inputs work (previously s.shape[1] was hard-coded). The pad stays
        # float32 as before; concatenation promotes to the wider dtype.
        pad = np.full((self.max_length - n,) + np.shape(s)[1:],
                      self.padding_value, dtype=np.float32)
        # np.concatenate always allocates a fresh array, so ``s`` is never
        # modified and no defensive deepcopy is needed.
        return np.concatenate((s, pad), axis=0)
class SpectrogramDataset(Dataset):
    """Map-style dataset of padded spectrograms with integer-encoded labels.

    ``path`` must contain an index file ``train_labels.txt`` or
    ``test_labels.txt`` (selected by ``train``). It must also contain a
    directory ``train/`` or ``test/`` holding the spectrogram files that the
    index lists.

    Args:
        path: Root directory of the dataset.
        class_mapping: Optional dict from raw label to merged label. Samples
            whose mapped label is falsy (e.g. ``None``) are skipped.
        train: Load the ``train`` split if True, else the ``test`` split.
        max_length: Length every sample is padded/truncated to. A value
            ``<= 0`` means "use the longest sample of this split".

    Raises:
        ValueError: If the index lists no usable samples.
    """

    def __init__(self, path, class_mapping=None, train=True, max_length=-1):
        t = 'train' if train else 'test'
        p = os.path.join(path, t)
        self.index = os.path.join(path, "{}_labels.txt".format(t))
        self.files, labels = self.get_files_labels(self.index, class_mapping)
        if not self.files:
            # Fail with a clear message rather than an IndexError on feats[0].
            raise ValueError("no samples listed in {}".format(self.index))
        self.feats = [read_spectrogram(os.path.join(p, f)) for f in self.files]
        self.feat_dim = self.feats[0].shape[1]
        self.lengths = [len(i) for i in self.feats]
        self.max_length = max(self.lengths) if max_length <= 0 else max_length
        self.zero_pad_and_stack = PaddingTransform(self.max_length)
        self.label_transformer = LabelTransformer()
        if isinstance(labels, (list, tuple)):
            self.labels = np.array(self.label_transformer.fit_transform(labels)).astype('int64')

    def get_files_labels(self, txt, class_mapping):
        """Read ``(file, label)`` pairs from a tab-separated index file.

        The first line is treated as a header and skipped, as are blank lines
        (e.g. a trailing newline at EOF). Column 0 is the file name and
        column 1 the raw label. When ``class_mapping`` is given, each label is
        replaced by its mapped value, and rows mapped to a falsy value are
        dropped.

        Returns:
            ``(files, labels)``: two parallel lists.
        """
        with open(txt, 'r') as fd:
            rows = [line.rstrip().split('\t')
                    for line in fd.readlines()[1:] if line.strip()]
        files, labels = [], []
        for row in rows:
            label = row[1]
            if class_mapping:
                label = class_mapping[label]
            if not label:
                continue
            files.append(row[0])
            labels.append(label)
        return files, labels

    def __getitem__(self, clip):
        """Return ``(padded_features, label, length)`` for sample ``clip``.

        ``length`` is the number of real (non-padding) rows in
        ``padded_features``. It is capped at ``max_length`` because longer
        samples are truncated by the padding transform.
        """
        length = min(self.lengths[clip], self.max_length)
        return self.zero_pad_and_stack(self.feats[clip]), self.labels[clip], length

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.files)
if __name__ == '__main__':
    # Example usage:
    #   - train/validation loaders over the training split (33% held out);
    #   - a loader over the test split.
    # NOTE(review): path presumably follows a Kaggle-style ../input layout;
    # adjust for local runs.
    data_root = '../input/data/data/fma_genre_spectrograms/'
    specs = SpectrogramDataset(data_root, train=True, class_mapping=class_mapping, max_length=-1)
    train_loader, val_loader = torch_train_val_split(specs, 32, 32, val_size=.33)
    # NOTE(review): max_length=-1 pads the test split to ITS longest sample,
    # which may differ from the training split's max_length. Consider passing
    # specs.max_length if the model expects a fixed input length.
    test_loader = DataLoader(SpectrogramDataset(data_root, train=False, class_mapping=class_mapping, max_length=-1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment