Created
October 3, 2017 11:50
-
-
Save thomwolf/3d1b008336e7ec41a00ce723703ac843 to your computer and use it in GitHub Desktop.
A PyTorch BatchSampler that enables large epochs on small datasets and balanced sampling from unbalanced datasets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DeepMojiBatchSampler(object):
    """A batch sampler that enables larger epochs on small datasets and
    has upsampling functionality.

    # Arguments:
        y_in: Labels of the dataset (a 1-D torch tensor of 0/1 labels when
            upsample=True; any tensor supporting len() and .numpy() otherwise).
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch. Must be even when
            upsample=True.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random number generator seed.

    # __iter__ output:
        iterator of index arrays (batches) into the dataset
    """

    def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
        self.batch_size = batch_size
        self.epoch_size = epoch_size
        self.upsample = upsample
        # NOTE(review): this seeds NumPy's *global* RNG (original behavior,
        # kept for compatibility); a private Generator would be cleaner.
        np.random.seed(seed)

        if upsample:
            # Balanced sampling: draw equally many negative and positive
            # examples, with replacement. Binary class problems only.
            assert len(y_in.shape) == 1
            assert epoch_size % 2 == 0
            neg = np.where(y_in.numpy() == 0)[0]
            pos = np.where(y_in.numpy() == 1)[0]
            samples_pr_class = epoch_size // 2
            sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
            sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
            concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)
            # Shuffle to avoid labels appearing in a specific order
            # (all negative then all positive).
            p = np.random.permutation(len(concat_ind))
            self.sample_ind = concat_ind[p]
            # Sanity-check that the sampled label distribution is balanced.
            label_dist = np.mean(y_in.numpy()[self.sample_ind])
            assert label_dist > 0.45
            assert label_dist < 0.55
        else:
            # Uniform sampling with replacement over the whole dataset,
            # which allows epoch_size > len(y_in).
            ind = range(len(y_in))
            self.sample_ind = np.random.choice(ind, epoch_size, replace=True)

    def __iter__(self):
        # Yield batches of at most batch_size indices, INCLUDING the final
        # (possibly incomplete) batch. The original used floor division
        # (int(epoch_size / batch_size)) here, silently dropping the last
        # partial batch and disagreeing with __len__'s ceiling division.
        for i in range(len(self)):
            start = i * self.batch_size
            end = min(start + self.batch_size, self.epoch_size)
            yield self.sample_ind[start:end]

    def __len__(self):
        # Ceiling division: counts the last (maybe incomplete) batch.
        return (self.epoch_size + self.batch_size - 1) // self.batch_size
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment