Skip to content

Instantly share code, notes, and snippets.

@RicherMans
Last active November 5, 2019 09:21
Show Gist options
  • Save RicherMans/57cce4d0666ce32736a1f65abed25299 to your computer and use it in GitHub Desktop.
Save RicherMans/57cce4d0666ce32736a1f65abed25299 to your computer and use it in GitHub Desktop.
Minimum Occupancy Sampler for SED
import torch
import torch.utils.data as dset
import torch
import random
import sklearn.datasets as skl
import numpy as np
import itertools
class MinimumOccupancySampler(dset.Sampler):
    """Sampler that emits indices in rows covering every non-empty label.

    Each "row" holds one sample index per label that has at least one
    positive sample. Labels with fewer positives than the most frequent
    label are padded with seeded random re-draws from their own pool, so
    every row covers all such labels. On each iteration the rows are put in
    a fresh random order and flattened into a single index stream, keeping
    the indices of one row adjacent.

    Args:
        labels: 2-D array-like (numpy array or CPU torch tensor) of shape
            ``(n_samples, n_labels)``; an entry equal to 1 marks the sample
            as positive for that label.
        sampling_factor: number of passes over the per-label pools. Each
            pass contributes as many rows as the largest pool has entries.
        random_state: seed for the internal ``np.random.RandomState``. It
            drives all shuffles, padding draws and row permutations.

    Attributes:
        labels: the ``labels`` argument as given.
        label_to_idx_list: list whose entry ``j`` holds the indices of the
            samples positive for label ``j`` (may be empty).
        random_state: the ``np.random.RandomState`` used for all randomness.
        data_source: int64 array of shape ``(n_rows, n_active_labels)``.
    """

    def __init__(self, labels, sampling_factor=1, random_state=None):
        self.labels = labels
        # np.asarray accepts numpy arrays and CPU torch tensors alike.
        label_matrix = np.asarray(labels)
        _, n_labels = label_matrix.shape
        self.random_state = np.random.RandomState(seed=random_state)
        self.label_to_idx_list = [
            np.flatnonzero(label_matrix[:, lb_idx] == 1)
            for lb_idx in range(n_labels)
        ]
        # A label without any positive sample cannot be represented, and
        # drawing padding from an empty pool would raise, so skip it.
        pools = [idx for idx in self.label_to_idx_list if idx.size > 0]
        rows = []
        for _ in range(sampling_factor):
            # Re-shuffle the within-label order every pass. Otherwise later
            # passes would repeat the first pass's pairings (permuting the
            # column order alone does not change which indices share a row).
            for pool in pools:
                self.random_state.shuffle(pool)
            self.random_state.shuffle(pools)
            for row in itertools.zip_longest(*pools):
                # Shorter (exhausted) labels show up as None; pad them with a
                # seeded re-draw from that label's own pool.
                rows.append([
                    self.random_state.choice(pool) if idx is None else idx
                    for idx, pool in zip(row, pools)
                ])
        # Explicit reshape keeps a well-formed 2-D int array even when empty.
        self.data_source = np.asarray(rows, dtype=np.int64).reshape(
            len(rows), len(pools))

    def __iter__(self):
        """Yield every stored index once, row by row, in a fresh row order."""
        order = self.random_state.permutation(len(self.data_source))
        # reshape(-1) flattens rows in order, so each row's indices stay
        # adjacent. tolist() yields plain Python ints for Dataset indexing.
        return iter(self.data_source[order].reshape(-1).tolist())

    def __len__(self):
        """Return the number of indices yielded by one iteration."""
        return int(self.data_source.size)
# Demo: build a toy multi-label dataset and print, per batch, how many
# positives each label receives when batches are drawn through
# MinimumOccupancySampler.
X, Y = skl.make_multilabel_classification(n_samples=100,
                                          n_features=10,
                                          n_classes=10,
                                          random_state=0)
X = torch.as_tensor(X)
Y = torch.as_tensor(Y)
tsd = dset.TensorDataset(X, Y)
sampler = MinimumOccupancySampler(Y, sampling_factor=2, random_state=0)
loader = dset.DataLoader(dataset=tsd, sampler=sampler, batch_size=20)
# Iterate twice. Each pass calls the sampler's __iter__ again, which draws
# a new row permutation, so the two passes print different counts.
for _epoch in range(2):
    for batch_idx, (_, y) in enumerate(loader):
        # Sum over the batch dimension: positives per label in this batch.
        print(batch_idx, y.sum(0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment