Created
March 9, 2019 23:05
-
-
Save YannDubs/3550259636987a7b460a200efbd6acf3 to your computer and use it in GitHub Desktop.
Stratified sampling using NumPy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def stratify_sampling(x, n_samples, stratify):
    """Perform stratified sampling (without replacement) of an array or tensor.

    The first dimension of `x` is treated as a sequence of contiguous
    subgroups (strata) whose sizes are given by `stratify`. Each stratum
    contributes a number of samples proportional to its size. The leftover
    caused by rounding is assigned with the largest-remainder method, so
    exactly `n_samples` distinct rows are returned.

    Parameters
    ----------
    x : np.ndarray or torch.Tensor
        Array to sample from. Samples along the first dimension.
    n_samples : int
        Number of samples to draw. Must satisfy ``0 <= n_samples <= x.shape[0]``.
    stratify : tuple of int
        Size of each contiguous subgroup. The sizes must be non-negative and
        sum to ``x.shape[0]``.

    Returns
    -------
    samples : same type as `x`
        ``x`` indexed by the sampled row indices, grouped by stratum in the
        order the strata appear in `stratify`.

    Raises
    ------
    ValueError
        If a stratum size is negative, if the sizes do not sum to
        ``x.shape[0]``, or if `n_samples` is outside ``[0, x.shape[0]]``.
    """
    n_total = x.shape[0]
    sizes = [int(s) for s in stratify]
    if any(s < 0 for s in sizes):
        raise ValueError(f"stratum sizes must be non-negative, got {sizes}")
    if sum(sizes) != n_total:
        raise ValueError(
            f"stratum sizes sum to {sum(sizes)} but x.shape[0] is {n_total}"
        )
    if not 0 <= n_samples <= n_total:
        raise ValueError(
            f"n_samples must be in [0, {n_total}], got {n_samples}"
        )

    counts = _allocate_stratum_counts(sizes, n_samples)
    bounds = np.cumsum([0] + sizes)

    # Sample without replacement *within* each stratum, so no index can be
    # drawn twice and every stratum keeps its proportional share.
    sampled_idcs = [
        np.random.choice(np.arange(start, stop), size=count, replace=False)
        for start, stop, count in zip(bounds[:-1], bounds[1:], counts)
    ]
    if sampled_idcs:
        idcs = np.concatenate(sampled_idcs).astype(np.intp, copy=False)
    else:
        # No strata at all (only possible when x is empty).
        idcs = np.empty(0, dtype=np.intp)
    return x[idcs, ...]


def _allocate_stratum_counts(sizes, n_samples):
    """Split `n_samples` across strata proportionally to `sizes` (largest remainder)."""
    n_total = sum(sizes)
    if n_total == 0:
        return [0] * len(sizes)

    counts = []
    remainders = []
    for size in sizes:
        # Exact integer arithmetic: avoids float rounding in size*n/N.
        quota, remainder = divmod(size * n_samples, n_total)
        counts.append(quota)
        remainders.append(remainder)

    # The floors undershoot by sum(remainders) / n_total < number of strata
    # with a non-zero remainder. Each such stratum has floor(size*n/N) < size
    # (given n_samples <= n_total), so adding one sample never exceeds its size.
    deficit = n_samples - sum(counts)
    # sorted() is stable, so ties are broken in favour of earlier strata.
    by_remainder = sorted(
        range(len(sizes)), key=lambda j: remainders[j], reverse=True
    )
    for j in by_remainder[:deficit]:
        counts[j] += 1
    return counts
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment