Created
October 31, 2017 17:37
-
-
Save ozancaglayan/8ff67259f73e18be078177f1ae5a7805 to your computer and use it in GitHub Desktop.
WeightedBatchSampler for PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class WeightedBatchSampler(Sampler):
    """Batch sampler drawing index batches from an updatable probability
    distribution over samples.

    The per-sample probabilities can be changed during iteration either by
    calling ``update_p`` directly or by ``send()``-ing score arrays into the
    generator returned by ``__iter__``.

    Arguments:
        n_elems (int): Total number of samples to draw indices from.
        batch_size (int): Number of indices per batch. Must be <= ``n_elems``
            since batches are drawn without replacement.
        initial_p (array-like, optional): Initial probability per sample.
            Defaults to a uniform distribution. NOTE(review): assumed to be a
            length-``n_elems`` array summing to 1 — not validated here.
        epoch_p_reset (bool): If ``True``, reset the distribution back to
            uniform at the end of each epoch.
    """

    def __init__(self, n_elems, batch_size,
                 initial_p=None, epoch_p_reset=False):
        self.n_elems = n_elems
        self.batch_size = batch_size
        self.epoch_p_reset = epoch_p_reset
        # ceil: a final, possibly partial chunk still counts as one batch
        self.n_batches = math.ceil(self.n_elems / self.batch_size)
        if initial_p is None:
            # Start with uniform probability for each sample
            self.p = np.ones(self.n_elems) / self.n_elems
        else:
            # BUGFIX: original read the undefined name `p` (NameError);
            # the constructor parameter is `initial_p`.
            self.p = initial_p

    def update_p(self, sample_idxs, sample_scores):
        """Cumulatively add ``sample_scores`` at ``sample_idxs`` and
        renormalize so ``self.p`` remains a probability distribution."""
        self.p[sample_idxs] += sample_scores
        self.p /= self.p.sum()

    def __iter__(self):
        """Yield ``n_batches`` index arrays of size ``batch_size``.

        Indices are drawn without replacement according to ``self.p``.
        The consumer may ``send()`` a score array for the batch just
        yielded; it is folded back into the distribution via ``update_p``.
        """
        ctr = 0
        while ctr < self.n_batches:
            bidxs = np.random.choice(self.n_elems, self.batch_size,
                                     replace=False, p=self.p)
            scores = (yield bidxs)
            if scores is not None:
                print(' Received scores')
                self.update_p(bidxs, scores)
            # Increment batch counter
            ctr += 1
        if self.epoch_p_reset:
            # Forget accumulated scores: back to uniform for the next epoch
            self.p = np.ones(self.n_elems) / self.n_elems

    def __len__(self):
        """Returns how many batches are inside."""
        return self.n_batches
if __name__ == '__main__':
    # Demo: draw batches and occasionally feed scores back into the sampler.
    sampler = WeightedBatchSampler(100000, 32)
    gen = iter(sampler)
    batch = next(gen)
    idx = 0
    while True:
        try:
            # Flip a biased coin and update scores
            if np.random.binomial(1, p=0.05):
                print('Sending scores back to generator at iteration %d' % idx)
                # NOTE(review): high is exclusive, so this always yields 1s —
                # presumably a constant bonus score; confirm intent.
                scores = np.random.randint(low=1, high=2, size=batch.size)
                # BUGFIX: send() resumes the generator and returns the NEXT
                # batch. The original discarded that return value and then
                # advanced again via the for-loop, silently dropping one
                # batch per send. Capture it instead.
                batch = gen.send(scores)
            else:
                batch = next(gen)
        except StopIteration:
            break
        idx += 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment