Skip to content

Instantly share code, notes, and snippets.

@ncullen93
Created April 27, 2017 01:22
Show Gist options
  • Save ncullen93/f2825de65a1e8cb5a256355cb1ac5314 to your computer and use it in GitHub Desktop.
Save ncullen93/f2825de65a1e8cb5a256355cb1ac5314 to your computer and use it in GitHub Desktop.
MultiSampler for PyTorch
class MultiSampler(Sampler):
    """Sample dataset indices more than once in a single pass over the data.

    This allows the number of samples per epoch to be larger than the number
    of items in the dataset, which can be useful for data augmentation.

    Args:
        nb_samples: Number of items in the underlying dataset; generated
            indices lie in ``[0, nb_samples)``.
        desired_samples: Number of indices to yield per epoch. It does not
            have to be a multiple of ``nb_samples``.
        shuffle: If true, the indices are randomly permuted on every pass.

    Raises:
        ValueError: If ``nb_samples`` is not positive or ``desired_samples``
            is negative.
    """

    def __init__(self, nb_samples, desired_samples, shuffle=False):
        if nb_samples <= 0:
            raise ValueError(
                f"nb_samples must be positive, got {nb_samples!r}"
            )
        if desired_samples < 0:
            raise ValueError(
                f"desired_samples must be non-negative, got {desired_samples!r}"
            )
        self.data_samples = nb_samples
        self.desired_samples = desired_samples
        self.shuffle = shuffle

    def gen_sample_array(self):
        """Build, store and return the index tensor for one epoch.

        Returns:
            A 1-D ``torch.long`` tensor with exactly ``desired_samples``
            entries, also stored on ``self.sample_idx_array``.
        """
        # Integer arithmetic: how many complete passes over the dataset fit,
        # and how many extra indices are needed to reach the requested total.
        full_passes, remainder = divmod(self.desired_samples, self.data_samples)

        base = torch.arange(self.data_samples, dtype=torch.long)
        indices = base.repeat(full_passes)

        if remainder:
            if self.shuffle:
                # Draw the partial pass from the whole dataset so that no
                # index is systematically excluded from the remainder.
                tail = torch.randperm(self.data_samples)[:remainder]
            else:
                tail = base[:remainder]
            indices = torch.cat([indices, tail])

        if self.shuffle:
            indices = indices[torch.randperm(len(indices))]

        self.sample_idx_array = indices
        return self.sample_idx_array

    def __iter__(self):
        # Yield plain Python ints rather than 0-dim tensors.
        return iter(self.gen_sample_array().tolist())

    def __len__(self):
        return self.desired_samples
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment