Last active
November 5, 2019 09:21
-
-
Save RicherMans/57cce4d0666ce32736a1f65abed25299 to your computer and use it in GitHub Desktop.
Minimum Occupancy Sampler for SED
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.utils.data as dset | |
import torch | |
import random | |
import sklearn.datasets as skl | |
import numpy as np | |
import itertools | |
class MinimumOccupancySampler(dset.Sampler):
    """Multi-label sampler that keeps every class present in each "row".

    Dataset indices are arranged into rows holding exactly one sample index
    per label column that has at least one positive example.

    - The number of rows per pass equals the size of the largest class.
    - Smaller classes are padded by re-drawing random positives of that
      class, so rare classes are oversampled.
    - The row matrix is generated ``sampling_factor`` times, each time with
      fresh within-class shuffles.

    Every call to ``__iter__`` permutes the rows and yields their indices
    row by row. Any row-aligned chunk therefore contains at least one
    positive of every represented class. For example, a DataLoader batch
    whose size is a multiple of the row length is such a chunk.

    Args:
        labels: 2-D multi-hot label matrix of shape (n_samples, n_labels),
            e.g. a numpy array or CPU torch tensor (anything ``np.asarray``
            accepts). Entries equal to 1 mark positives. Label columns
            without any positive are skipped.
        sampling_factor: Number of times the row matrix is generated.
        random_state: Seed for the internal ``np.random.RandomState``. All
            randomness (shuffling, padding, per-epoch permutation) is drawn
            from it, so a fixed seed gives reproducible output.
    """

    def __init__(self, labels, sampling_factor=1, random_state=None):
        self.labels = labels
        label_matrix = np.asarray(labels)
        _, n_labels = label_matrix.shape
        self.random_state = np.random.RandomState(seed=random_state)

        # Positive sample indices for every label column that has any.
        self.label_to_idx_list = []
        for lb_idx in range(n_labels):
            label_indexes = np.flatnonzero(label_matrix[:, lb_idx] == 1)
            if label_indexes.size == 0:
                # Such a class can never be represented; keeping it would
                # force padding to draw from an empty array.
                continue
            self.label_to_idx_list.append(label_indexes)
        n_active = len(self.label_to_idx_list)

        rows = []
        for _ in range(sampling_factor):
            # Reshuffle inside each class so every repetition pairs samples
            # differently, then shuffle the class order so a class does not
            # always occupy the same position inside a row.
            per_class = [
                self.random_state.permutation(idxs)
                for idxs in self.label_to_idx_list
            ]
            self.random_state.shuffle(per_class)
            for indexes in itertools.zip_longest(*per_class):
                # zip_longest pads exhausted (smaller) classes with None;
                # replace each gap with a random positive of that class.
                rows.append([
                    self.random_state.choice(per_class[col])
                    if idx is None else idx
                    for col, idx in enumerate(indexes)
                ])

        if rows:
            self.data_source = np.asarray(rows, dtype=np.int64)
        else:
            # sampling_factor == 0 or no usable classes: nothing to sample.
            self.data_source = np.empty((0, n_active), dtype=np.int64)

    def __iter__(self):
        """Yield all sample indices, row by row, in a fresh row order."""
        order = self.random_state.permutation(len(self.data_source))
        return iter(self.data_source[order].ravel().tolist())

    def __len__(self):
        """Return the total number of indices one iteration yields."""
        return int(self.data_source.size)
# Demo: build a small synthetic multi-label dataset and print, for every
# batch, how many positives of each class it contains.
X, Y = skl.make_multilabel_classification(n_samples=100,
                                          n_features=10,
                                          n_classes=10,
                                          random_state=0)
X = torch.as_tensor(X)
Y = torch.as_tensor(Y)
tsd = dset.TensorDataset(X, Y)
sampler = MinimumOccupancySampler(Y, sampling_factor=2, random_state=0)
# Run two passes ("epochs") over the same dataset and sampler. Each pass
# iterates the sampler anew, so the batch composition may differ between
# the two printouts.
for _epoch in range(2):
    loader = dset.DataLoader(dataset=tsd, sampler=sampler, batch_size=20)
    for batch_idx, (x, y) in enumerate(loader):
        print(batch_idx, y.sum(0))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment