Created
April 27, 2017 01:22
-
-
Save ncullen93/f2825de65a1e8cb5a256355cb1ac5314 to your computer and use it in GitHub Desktop.
MultiSampler for PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MultiSampler(Sampler):
    """Sample dataset indices more than once in a single pass through the data.

    This allows the number of samples per epoch to be larger than the number
    of samples itself, which can be useful for data augmentation. Exactly
    ``desired_samples`` indices are produced per epoch. Each index in
    ``range(nb_samples)`` appears either ``desired_samples // nb_samples``
    times or one more time than that.

    Args:
        nb_samples: Number of items in the underlying dataset. Indices are
            drawn from ``range(nb_samples)``. Must be positive.
        desired_samples: Number of indices to yield per epoch. Must be
            non-negative.
        shuffle: If True, the indices filling the final partial pass are
            picked at random, and the whole sequence is permuted on every
            call to ``__iter__``. If False, indices cycle in order
            ``0, 1, ..., nb_samples - 1, 0, 1, ...``.

    Raises:
        ValueError: If ``nb_samples`` is not positive or ``desired_samples``
            is negative.
    """

    # NOTE(review): no imports are visible in this file; this class needs
    # ``import torch`` and ``from torch.utils.data import Sampler`` at module
    # level.

    def __init__(self, nb_samples, desired_samples, shuffle=False):
        if nb_samples <= 0:
            raise ValueError(
                "nb_samples must be positive, got {}".format(nb_samples)
            )
        if desired_samples < 0:
            raise ValueError(
                "desired_samples must be non-negative, got {}".format(
                    desired_samples
                )
            )
        self.data_samples = nb_samples
        self.desired_samples = desired_samples
        self.shuffle = shuffle

    def gen_sample_array(self):
        """Build the index tensor for one epoch.

        Returns:
            A 1-D ``torch.long`` tensor of length ``desired_samples`` with
            values in ``[0, nb_samples)``. The tensor is also stored on
            ``self.sample_idx_array``.
        """
        # Split the request into whole passes over the data plus one partial
        # pass. The result then has exactly ``desired_samples`` entries, even
        # when that is not a multiple of ``nb_samples``, and agrees with
        # ``__len__``.
        n_repeats, remainder = divmod(self.desired_samples, self.data_samples)
        base = torch.arange(self.data_samples, dtype=torch.long)
        if self.shuffle:
            # Choose the partial pass at random, so the extra repetition does
            # not always go to the lowest indices.
            extra = torch.randperm(self.data_samples)[:remainder]
        else:
            extra = base[:remainder]
        sample_idx = torch.cat([base.repeat(n_repeats), extra])
        if self.shuffle:
            sample_idx = sample_idx[torch.randperm(len(sample_idx))]
        self.sample_idx_array = sample_idx
        return sample_idx

    def __iter__(self):
        # Convert to plain Python ints so that datasets which only accept
        # integer indices work too, not just those that accept 0-d tensors.
        return iter(self.gen_sample_array().tolist())

    def __len__(self):
        return self.desired_samples
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment