Skip to content

Instantly share code, notes, and snippets.

@awoimbee
Created May 15, 2020 17:32
Show Gist options
  • Save awoimbee/92eff12f6f8d58b3e45ef1f7cd94df57 to your computer and use it in GitHub Desktop.
Save awoimbee/92eff12f6f8d58b3e45ef1f7cd94df57 to your computer and use it in GitHub Desktop.
Poisson random numbers with Numba
# from https://scicomp.stackexchange.com/questions/27330/how-to-generate-poisson-distributed-random-numbers-quickly-and-accurately
# cupy is much, much faster than this
import numpy as np
from numba import jit
from math import *
import random
@jit()
def poissrnd(mean: float) -> int:
    """Return one Poisson-distributed integer with expected value ``mean``.

    Two regimes are used:

    * ``mean < 60``: multiplicative (Knuth) method -- keep multiplying
      uniform draws while the running product stays above ``exp(-mean)``;
      the number of extra multiplications is the sample.
    * otherwise: rejection sampling. A proposal ``x`` is drawn from a
      Cauchy density centred on ``mean`` with scale ``sqrt(mean)``
      (negative proposals are redrawn), floored to ``k``, and accepted
      unless a fresh uniform exceeds ``pmf(k) / cauchy_pdf(x) / 2.4``.
      (2.4 presumably bounds that ratio so it stays <= 1 -- TODO confirm.)
    """
    if mean < 60:
        threshold = exp(-mean)
        draws = 0
        product = random.random()
        while product > threshold:
            product *= random.random()
            draws += 1
        return draws

    scale = sqrt(mean)
    log_mean = log(mean)
    while True:
        # Cauchy(mean, scale) proposal; redraw until it is not negative.
        x = mean + scale * tan(pi * (random.random() - 0.5))
        while x < 0:
            x = mean + scale * tan(pi * (random.random() - 0.5))
        k = floor(x)
        envelope = scale / (pi * ((x - mean) * (x - mean) + mean))
        pmf = exp(k * log_mean - mean - lgamma(k + 1))
        # Accept unless the uniform lands above the scaled ratio.
        if not (random.random() > pmf / envelope / 2.4):
            break
    return int(k)
@jit()
def poissrnd_bis(mean: float) -> int:
    """Return one Poisson-distributed integer with expected value ``mean``.

    Appears to be a port of ``poidev`` from Numerical Recipes in C:

    * ``mean < 12``: direct method -- multiply uniform draws while the
      running product stays above ``exp(-mean)``; the number of extra
      multiplications is the sample.
    * otherwise: rejection method. With ``y = tan(pi * U)`` (a Lorentzian
      deviate), the proposal ``sqrt(2 * mean) * y + mean`` is redrawn until
      it is non-negative, floored to ``em``, and accepted when a fresh
      uniform is <= ``t = 0.9 * (1 + y*y) * exp(em*log(mean)
      - lgamma(em + 1) - g)``, where ``g = mean*log(mean) - lgamma(mean + 1)``.

    Both branches return an ``int``.
    """
    if mean < 12.0:
        # Direct method.
        g = exp(-mean)
        res = 0
        t = random.random()
        while t > g:
            res += 1
            t *= random.random()
        return res

    # Rejection method; these quantities depend only on ``mean``.
    sq_mean = sqrt(2.0 * mean)
    lg_mean = log(mean)
    g = mean * lg_mean - lgamma(mean + 1.0)
    # Propose first, then test (a do-while), so ``em`` is always bound when
    # the loop exits -- even if random.random() returns exactly 0.0.
    while True:
        # Non-negative Lorentzian deviate centred on the mean.
        while True:
            y = tan(pi * random.random())
            x = sq_mean * y + mean
            if not (x < 0.0):
                break
        em = floor(x)
        # The factor 0.9 is chosen so that t never exceeds 1.
        t = 0.9 * (1.0 + y * y) * exp(em * lg_mean - lgamma(em + 1.0) - g)
        if random.random() <= t:
            break
    # Explicit int so this branch matches the direct branch, poissrnd, and
    # the ``-> int`` annotation.
    return int(em)
@jit()
def poissrnd_array(array):
    """Draw one Poisson sample per element, treating ``array`` as weights.

    Element ``i`` is sampled with mean ``len(array) * array[i] / array.sum()``,
    so the per-element means add up to ``len(array)`` (average mean 1).
    Samples come from ``poissrnd_bis`` and are stored in an int64 array
    created with ``np.empty_like(array)``.
    """
    n = len(array)
    total = array.sum()
    out = np.empty_like(array, dtype=np.int64)
    for i in range(n):
        out[i] = poissrnd_bis(n * array[i] / total)
    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment