Created
October 8, 2017 10:22
-
-
Save kdubovikov/d4e5c688fa771227fdf8c924196a59fe to your computer and use it in GitHub Desktop.
Fast random subset sampling with Cython
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%cython | |
import numpy as np | |
cimport numpy as np | |
cimport cython # so we can use cython decorators | |
from cpython cimport bool # type annotation for boolean | |
# Bounds checking and negative-index wraparound are switched off for speed.
@cython.wraparound(False)
@cython.boundscheck(False)
cdef cython_get_sample(np.ndarray arr, arr_len, n_iter, int sample_size,
                       bool fast):
    """Draw ``sample_size`` elements from ``arr`` without replacement.

    Parameters
    ----------
    arr : np.ndarray
        Population to sample from. In fast mode it is shuffled IN PLACE
        whenever the current window reaches the end of the array.
    arr_len
        Length of ``arr``, supplied by the caller.
    n_iter
        Zero-based index of the sample being drawn; in fast mode it selects
        which window of ``arr`` is returned.
    sample_size : int
        Number of elements to return.
    fast : bool
        If true, return a contiguous window of ``arr`` instead of calling
        ``np.random.choice``.

    Returns
    -------
    np.ndarray
        Fast mode: the slice ``arr[start:start + sample_size]``.
        Slow mode: the result of
        ``np.random.choice(arr, sample_size, replace=False)``.

    Raises
    ------
    ValueError
        If ``sample_size`` exceeds ``arr_len`` in fast mode (slow mode relies
        on ``np.random.choice`` to reject that case).
    """
    cdef Py_ssize_t start_idx

    if not fast:
        return np.random.choice(arr, sample_size, replace=False)

    if sample_size > arr_len:
        raise ValueError(
            "Cannot take a larger sample than population in fast mode")

    start_idx = (n_iter * sample_size) % arr_len
    if start_idx + sample_size >= arr_len:
        # The window touches or overruns the end of the array: reshuffle so
        # later windows see a fresh permutation.
        np.random.shuffle(arr)
        # Pull the window back so it lies fully inside the array. Without
        # this the slice is silently truncated to fewer than sample_size
        # elements whenever the window overruns the end.
        start_idx = arr_len - sample_size
    return arr[start_idx:start_idx + sample_size]
@cython.wraparound(False)
@cython.boundscheck(False)
def cython_collect_samples(np.ndarray arr,
                           int sample_size,
                           int n_samples,
                           bool fast=False):
    """Collect ``n_samples`` random subsets of ``arr`` into one 2-D array.

    Each subset is produced by ``cython_get_sample``; with ``fast=True`` that
    helper may shuffle ``arr`` in place.

    Parameters
    ----------
    arr : np.ndarray
        Population to sample from.
    sample_size : int
        Number of elements in each subset.
    n_samples : int
        Number of subsets to draw.
    fast : bool, optional
        Forwarded to ``cython_get_sample``; defaults to ``False``.

    Returns
    -------
    np.ndarray
        ``int64`` array of shape ``(n_samples + 1, sample_size)``. Rows
        ``0 .. n_samples - 1`` hold the subsets; the final row is never
        written and stays zero.
    """
    cdef np.ndarray samples
    cdef np.ndarray sample
    cdef int arr_len
    cdef int sample_n  # typed so the range() loop compiles to a plain C loop

    # Allocate all memory in advance.
    # NOTE(review): one more row than is ever filled is allocated here; it
    # looks like an off-by-one, but the shape is kept because callers may
    # rely on it -- confirm before changing.
    # NOTE(review): the buffer is int64, so values of any other dtype are
    # cast on assignment -- presumably arr holds integer ids; verify.
    samples = np.zeros((n_samples + 1, sample_size), np.int64)
    arr_len = len(arr)

    for sample_n in range(n_samples):
        sample = cython_get_sample(arr, arr_len, sample_n, sample_size, fast)
        samples[sample_n] = sample

    return samples
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment