Last active
April 22, 2021 11:14
-
-
Save dbalabka/9c0a3d88312f6d51af8949e85758aa7a to your computer and use it in GitHub Desktop.
scikits-bootstrap: bootstrap resampling with Numba, plus performance testing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scikits.bootstrap as bootstrap | |
import numpy as np | |
import time | |
import numba | |
@numba.njit(parallel=True, fastmath=True)
def _calculate_boostrap_mean_stat(data: np.ndarray, n_samples: int) -> np.ndarray:
    """Compute ``n_samples`` bootstrap replicates of the mean of ``data``.

    Every replicate draws ``len(data)`` elements from ``data`` with
    replacement (``np.random.choice``) and stores their mean.  The outer
    loop is a ``numba.prange``, so replicates may be computed in parallel.

    Returns a float array of length ``n_samples``.
    """
    size = len(data)
    means = np.empty(n_samples)  # every slot is written below, so no zero-fill needed
    for k in numba.prange(n_samples):
        resample = np.random.choice(data, size)
        means[k] = resample.mean()
    return means
# Shared benchmark inputs: one 100k-element integer sample (wrapped in a tuple,
# mirroring how scikits-bootstrap accepts data) and the number of resamples.
tdata = (np.random.randint(0, 5, 100_000), )
n_samples = 10_000

# Baseline: scikits-bootstrap resampling indices + a NumPy mean per resample.
# time.perf_counter() is monotonic and high-resolution, which makes it the
# right clock for elapsed-time measurements (time.time() is neither).
start = time.perf_counter()
bootindices = bootstrap.bootstrap_indices(tdata[0], n_samples)
# np.mean(*(...)) unpacks the generator, i.e. np.mean(tdata[0][indices]) for a
# one-array tdata. Extra arrays would land in np.mean's positional parameters
# (axis, dtype, ...), so this form only works with a single array.
stat_old = np.array([np.mean(*(x[indices] for x in tdata))
                     for indices in bootindices])
end = time.perf_counter()
print(end - start)
# Pure-NumPy variant: draw each resample with a Generator and average it.
# time.perf_counter() is used instead of time.time() because it is monotonic
# and high-resolution, as intended for measuring durations.
start = time.perf_counter()
rng = np.random.default_rng()
stat_new1 = np.array([np.mean(*(rng.choice(x, tdata[0].shape[0]) for x in tdata))
                      for _ in range(n_samples)])
end = time.perf_counter()
print(end - start)
# Numba variant. The first call of an @njit function compiles it, so run a
# tiny warm-up with the same argument types (1-D int array, int) first.
# Otherwise JIT compile time would be reported as execution time.
_calculate_boostrap_mean_stat(tdata[0], 1)
start = time.perf_counter()
stat_new2 = _calculate_boostrap_mean_stat(tdata[0], n_samples)
end = time.perf_counter()
print(end - start)
# Sanity checks: each implementation yields one statistic per resample, and
# the averages of those statistics must agree across implementations.
print(f'{stat_old.shape} == {stat_new1.shape}')
print(f'{stat_old.shape} == {stat_new2.shape}')
print(f'{stat_old.mean()} == {stat_new1.mean()}')
print(f'{stat_old.mean()} == {stat_new2.mean()}')
assert stat_old.shape == stat_new1.shape
assert stat_old.shape == stat_new2.shape
# Use an absolute tolerance rather than round(..., 3) equality. Two independently
# resampled means can differ by far less than 1e-3 yet round to different third
# decimals when they straddle a rounding boundary, which made the old check flaky.
# assert_allclose also raises under ``python -O``, unlike a bare assert.
np.testing.assert_allclose(stat_new1.mean(), stat_old.mean(), rtol=0, atol=1e-3)
np.testing.assert_allclose(stat_new2.mean(), stat_old.mean(), rtol=0, atol=1e-3)
# Numba debug. The kernel is defined in this script, not in scikits.bootstrap,
# so call the diagnostics on it directly:
# _calculate_boostrap_mean_stat.parallel_diagnostics(level=4)
# _calculate_boostrap_mean_stat.inspect_types()
Author
dbalabka
commented
Apr 22, 2021
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment