Created
May 15, 2020 17:32
-
-
Save awoimbee/92eff12f6f8d58b3e45ef1f7cd94df57 to your computer and use it in GitHub Desktop.
Poisson random numbers with Numba
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# from https://scicomp.stackexchange.com/questions/27330/how-to-generate-poisson-distributed-random-numbers-quickly-and-accurately | |
# cupy is much, much faster than this | |
import numpy as np | |
from numba import jit | |
from math import * | |
import random | |
@jit() | |
def poissrnd(mean: float) -> int:
    """Return a single Poisson(mean) random integer.

    Small means (``mean < 60``) use Knuth's multiplication method, which
    consumes roughly ``mean`` uniform draws. Larger means use rejection
    sampling with a Cauchy proposal centred on ``mean`` with scale
    ``sqrt(mean)``.

    Args:
        mean: Expected value of the distribution. Any ``mean <= 0`` yields 0.

    Returns:
        A non-negative integer sample.
    """
    if mean < 60:
        # Count how many uniforms can be multiplied together before the
        # running product falls to exp(-mean) or below.
        threshold = exp(-mean)
        k = 0
        prod = random.random()
        while prod > threshold:
            k += 1
            prod *= random.random()
        return k

    scale = sqrt(mean)
    log_mean = log(mean)
    while True:
        # Inverse-CDF draw from Cauchy(loc=mean, scale=sqrt(mean)); redraw
        # until the candidate is non-negative.
        x = -1.0
        while x < 0:
            x = mean + scale * tan(pi * (random.random() - 0.5))
        k = floor(x)
        # Poisson pmf at k, computed in log space so mean**k cannot overflow.
        pmf = exp(k * log_mean - mean - lgamma(k + 1))
        # Cauchy proposal density at x.
        density = scale / (pi * ((x - mean) * (x - mean) + mean))
        # 2.4 scales the envelope; the method assumes pmf(floor(x)) <= 2.4 *
        # density(x) for mean >= 60 (NOTE(review): constant not re-derived here).
        ratio = pmf / density / 2.4
        if not random.random() > ratio:
            break
    return int(k)
@jit() | |
def poissrnd_bis(mean: float) -> int:
    """Draw one Poisson-distributed integer with expectation ``mean``.

    Port of the ``poidev`` routine from Numerical Recipes:

    * ``mean < 12``: direct method, multiplying uniforms until the running
      product drops to ``exp(-mean)`` or below.
    * otherwise: rejection sampling against a Lorentzian (Cauchy) comparison
      function, scaled by 0.9 so the acceptance ratio stays below 1.

    Args:
        mean: Expected value of the distribution. Any ``mean <= 0`` yields 0.

    Returns:
        A non-negative integer sample.
    """
    if mean < 12.0:
        # Direct method: count uniforms multiplied together before the
        # running product drops to exp(-mean) or below.
        threshold = exp(-mean)
        count = 0
        prod = random.random()
        while prod > threshold:
            count += 1
            prod *= random.random()
        return count

    # Rejection method. These quantities depend only on `mean`, so compute
    # them once outside the sampling loop.
    sq_mean = sqrt(2.0 * mean)
    lg_mean = log(mean)
    g = mean * lg_mean - lgamma(mean + 1.0)
    # The reference implementation is a do-while: always draw at least one
    # candidate before testing it. (A `while u > t` guard with t = 0 skipped
    # the body entirely when the first uniform was exactly 0.0, leaving the
    # result unbound.)
    while True:
        # Sample the Lorentzian comparison function, discarding negative
        # candidates.
        while True:
            y = tan(pi * random.random())
            em = sq_mean * y + mean
            if not em < 0.0:
                break
        k = floor(em)
        # The factor 0.9 is chosen so that t never exceeds 1.
        t = 0.9 * (1.0 + y * y) * exp(k * lg_mean - lgamma(k + 1.0) - g)
        # random.random() lies in [0, 1), so `u < t` accepts with probability
        # t and never accepts a candidate whose ratio is 0.
        if random.random() < t:
            break
    # Explicit int so the result type matches the annotation (and poissrnd).
    return int(k)
@jit() | |
def poissrnd_array(array):
    """Draw one Poisson sample per element of ``array`` via ``poissrnd_bis``.

    Each element acts as a relative weight: element ``i`` gets the Poisson
    mean ``len(array) * array[i] / array.sum()``. The means therefore average
    to 1, and the expected total of the result is ``len(array)``.

    Args:
        array: numeric weights; assumed 1-D and non-negative (TODO confirm
            with callers). NOTE(review): a zero sum makes every mean
            ill-defined and is not guarded here.

    Returns:
        An ``int64`` array with the same shape as ``array``.
    """
    total = array.sum()
    n_items = len(array)
    samples = np.empty_like(array, dtype=np.int64)
    for i, weight in enumerate(array):
        samples[i] = poissrnd_bis(n_items * weight / total)
    return samples
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment