Skip to content

Instantly share code, notes, and snippets.

@georgepar
Created December 20, 2018 15:25
Show Gist options
  • Save georgepar/8a8f75c87732159b3699fb46d9523d78 to your computer and use it in GitHub Desktop.
import os

import numpy as np
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
class LabelTransformer(LabelEncoder):
    """``LabelEncoder`` whose ``transform``/``inverse`` also accept a single label.

    A scalar (0-d) argument is wrapped in a one-element list before it is
    handed to the parent class, so the result is always a 1-D array (of
    length 1 for a scalar). Sequences and arrays are passed through unchanged.
    """

    @staticmethod
    def _as_1d(y):
        """Wrap a 0-d (scalar) value in a list; return anything else unchanged."""
        return [y] if np.ndim(y) == 0 else y

    def inverse(self, y):
        """Decode integer code(s) ``y`` back to the original label(s).

        Thin wrapper over ``LabelEncoder.inverse_transform`` that also accepts
        a scalar code.
        """
        # Wrapping is decided from the input's dimensionality up front, so any
        # genuine error raised by the parent propagates unchanged instead of
        # being replaced by the error of a second, retried call.
        return super(LabelTransformer, self).inverse_transform(self._as_1d(y))

    def transform(self, y):
        """Encode label(s) ``y`` as integer code(s).

        Thin wrapper over ``LabelEncoder.transform`` that also accepts a
        scalar label.
        """
        return super(LabelTransformer, self).transform(self._as_1d(y))
class SpectrogramDataset(Dataset):
    """Map-style dataset of zero-padded spectrogram features with integer labels.

    ``__init__`` reads the index file ``<path>/<split>_labels.txt`` and the
    feature files ``<path>/<split>/<file>``, where ``<split>`` is ``'train'``
    or ``'test'``. ``read_spectrogram`` and ``get_files_labels`` raise
    ``NotImplementedError`` here and must be provided (e.g. by a subclass)
    for the concrete data format.

    Attributes:
        index: path of the ``<split>_labels.txt`` index file.
        files: file names returned by ``get_files_labels``.
        feat_dim: size of axis 1 of the first feature array.
        lengths: per-item size of axis 0, recorded before padding.
        feats: padded, stacked features (output of ``zero_pad_and_stack``).
        label_transformer: encoder mapping raw labels to integer codes.
        labels: ``int64`` array of encoded labels, aligned with ``files``.
    """

    def __init__(self, path, train=True, label_transformer=None):
        """Load, pad and label every item of one split.

        Args:
            path: dataset root directory.
            train: read the ``train`` split if True, otherwise the ``test`` split.
            label_transformer: optional, already-fitted encoder exposing a
                ``transform`` method. Pass the training set's
                ``label_transformer`` when building the test set so both
                splits share one label->index mapping. When omitted, a new
                ``LabelTransformer`` is fitted on this split's labels alone.

        Raises:
            ValueError: if the index lists no files, or the number of labels
                differs from the number of files.
        """
        t = 'train' if train else 'test'
        p = os.path.join(path, t)
        self.index = os.path.join(path, "{}_labels.txt".format(t))
        self.files, labels = self.get_files_labels(self.index)
        # list() accepts any iterable of labels (list, tuple, ndarray, ...),
        # so self.labels is always set below.
        labels = list(labels)
        if len(self.files) == 0:
            raise ValueError("No files listed in {}".format(self.index))
        if len(labels) != len(self.files):
            raise ValueError(
                "{} lists {} files but {} labels".format(
                    self.index, len(self.files), len(labels)))
        self.feats = [self.read_spectrogram(os.path.join(p, f)) for f in self.files]
        self.feat_dim = self.feats[0].shape[1]
        # Unpadded lengths (returned by __getitem__) must be taken before padding.
        self.lengths = [len(i) for i in self.feats]
        self.feats = self.zero_pad_and_stack(self.feats)
        if label_transformer is None:
            self.label_transformer = LabelTransformer()
            encoded = self.label_transformer.fit_transform(labels)
        else:
            self.label_transformer = label_transformer
            encoded = self.label_transformer.transform(labels)
        self.labels = np.asarray(encoded).astype('int64')

    def read_spectrogram(self, spectrogram_file):
        """Return the transposed fused spectrogram stored in ``spectrogram_file``.

        ``__init__`` uses ``len()`` of the result as the item length and
        ``shape[1]`` as the feature dimension, so the result should be an
        array of shape ``(length, feat_dim)``.
        """
        raise NotImplementedError

    def get_files_labels(self, txt):
        """Read the ``[train|test]_labels.txt`` index at ``txt``.

        Returns:
            ``(files, labels)``: file names relative to the split directory,
            and a same-length iterable of their raw labels.
        """
        raise NotImplementedError

    def zero_pad_and_stack(self, x):
        """Zero-pad variable-length arrays along axis 0 and stack them.

        Args:
            x: non-empty sequence of arrays of shape ``(length_i, *rest)``,
                all sharing the same trailing shape ``rest``.

        Returns:
            Array of shape ``(len(x), max(length_i), *rest)`` with the dtype
            of ``x[0]``; positions past each item's own length are zero.

        Raises:
            ValueError: if ``x`` is empty.
        """
        if len(x) == 0:
            raise ValueError("cannot pad and stack an empty sequence")
        first = np.asarray(x[0])
        max_length = max(len(item) for item in x)
        padded = np.zeros((len(x), max_length) + first.shape[1:], dtype=first.dtype)
        for i, item in enumerate(x):
            padded[i, :len(item)] = item
        return padded

    def __getitem__(self, item):
        """Return ``(padded_features, label, unpadded_length)`` for index ``item``."""
        return self.feats[item], self.labels[item], self.lengths[item]

    def __len__(self):
        """Return the number of items (labels) in the dataset."""
        return len(self.labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment