Created December 20, 2018 15:25
-
-
Save georgepar/8a8f75c87732159b3699fb46d9523d78 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
class LabelTransformer(LabelEncoder):
    """``LabelEncoder`` whose ``transform``/``inverse`` also accept one scalar label.

    For array-like input both methods behave exactly like
    ``LabelEncoder.transform`` / ``LabelEncoder.inverse_transform``. A scalar
    (0-d) input is wrapped in a one-element list first, so the result is always
    a 1-D array.
    """

    def inverse(self, y):
        """Map encoded label(s) ``y`` back to the original label values."""
        return super().inverse_transform(self._as_1d(y))

    def transform(self, y):
        """Encode label(s) ``y`` using the classes learned during ``fit``."""
        return super().transform(self._as_1d(y))

    @staticmethod
    def _as_1d(y):
        """Wrap a scalar label in a list; pass array-like input through unchanged."""
        # Decide explicitly instead of retrying after *any* exception: the old
        # bare ``except`` also caught KeyboardInterrupt and genuine errors
        # (e.g. unseen labels in a list), then retried with ``[y]`` and raised a
        # misleading "expected 1d array" error in place of the real one.
        return [y] if np.ndim(y) == 0 else y
class SpectrogramDataset(Dataset):
    """Zero-padded spectrogram features with integer-encoded labels.

    Reads the index file ``<path>/<split>_labels.txt`` and the feature files it
    lists from ``<path>/<split>/``, where ``<split>`` is ``train`` or ``test``.
    ``read_spectrogram``, ``get_files_labels`` and ``zero_pad_and_stack`` are
    not implemented here and must be provided (e.g. by a subclass).

    Each item is a ``(features, label, length)`` tuple, where ``length`` is the
    number of frames of that file before padding.

    Args:
        path: Root directory holding the split sub-directory and index file.
        train: Load the ``train`` split if True, otherwise the ``test`` split.
        label_transformer: Optional, already-fitted ``LabelTransformer``.
            Pass the training dataset's ``label_transformer`` when building the
            test dataset, so both splits share one label -> index mapping even
            if a split lacks some classes. If None (default), a new transformer
            is fitted on this split's labels.

    Raises:
        ValueError: if the number of labels differs from the number of files,
            or if the index lists no files.
    """

    def __init__(self, path, train=True, label_transformer=None):
        split = 'train' if train else 'test'
        feats_dir = os.path.join(path, split)
        self.index = os.path.join(path, "{}_labels.txt".format(split))
        self.files, labels = self.get_files_labels(self.index)
        # Accept any iterable of labels (list, tuple, ndarray, ...); previously
        # only list/tuple were encoded and anything else left ``self.labels``
        # unset, so __len__/__getitem__ failed later with AttributeError.
        labels = list(labels)
        if len(labels) != len(self.files):
            raise ValueError("{} lists {} files but {} labels".format(
                self.index, len(self.files), len(labels)))

        self.feats = [self.read_spectrogram(os.path.join(feats_dir, f))
                      for f in self.files]
        if not self.feats:
            raise ValueError("No feature files listed in {}".format(self.index))
        # This class treats each array as (frames, features): shape[1] is the
        # feature dimension and len() is the sequence length.
        self.feat_dim = self.feats[0].shape[1]
        # Record the true lengths before padding makes them all equal.
        self.lengths = [len(i) for i in self.feats]
        self.feats = self.zero_pad_and_stack(self.feats)

        if label_transformer is None:
            self.label_transformer = LabelTransformer()
            encoded = self.label_transformer.fit_transform(labels)
        else:
            # Reuse an existing mapping (typically the train split's).
            self.label_transformer = label_transformer
            encoded = self.label_transformer.transform(labels)
        self.labels = np.array(encoded).astype('int64')

    def read_spectrogram(self, spectrogram_file):
        """Return the transposed, fused spectrogram stored in ``spectrogram_file``.

        ``__init__`` uses the result as a 2-D array: ``len()`` gives the number
        of frames and ``shape[1]`` the feature dimension.
        """
        raise NotImplementedError

    def get_files_labels(self, txt):
        """Read the ``[train|test]_labels.txt`` index at ``txt``.

        Returns:
            ``(files, labels)``: file names relative to the split directory,
            and one label per file in the same order.
        """
        raise NotImplementedError

    def zero_pad_and_stack(self, x):
        """Zero-pad the per-file feature arrays in ``x`` and stack them.

        The result replaces ``self.feats`` and is indexed per item by
        ``__getitem__``.
        """
        # Pad Features. You can use code from Lab 2
        raise NotImplementedError

    def __getitem__(self, item):
        """Return ``(padded features, encoded label, unpadded length)`` of ``item``."""
        return self.feats[item], self.labels[item], self.lengths[item]

    def __len__(self):
        """Return the number of examples in this split."""
        return len(self.labels)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment