Fast sampling from arbitrary probability densities in Python
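The code below builds an approximate sampler for an arbitrary 1-D density: it numerically integrates the PDF into a CDF on a fixed grid, inverts that CDF through a lookup table, and spreads the lookups across a multiprocessing pool. Example usage and rough timings are in the `__main__` block at the bottom.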
'''
Tools for sampling from arbitrary probability densities.

Requirements:
    pip install scipy numpy

John Meade 2019
MIT license
'''
import numpy as np
import scipy.interpolate
from random import random
from multiprocessing import cpu_count, Pool

def pdf2cdf(pdf, xmin, xmax, resolution=100):
    '''
    Create an approximate CDF function via numerical integration of a PDF.
    Args:
        pdf (callable): vectorized callable PDF function
        xmin (float): left boundary of the approximation domain
        xmax (float): right boundary of the approximation domain
        resolution (int): number of grid points, which controls the accuracy
            of the approximation
    Returns:
        cdf (callable): CDF
    '''
    x = np.linspace(xmin, xmax, resolution)
    y = pdf(x)
    # cumulative sum of PDF values approximates the integral up to each grid point
    cs = np.cumsum(y)
    # normalize so the approximate CDF runs from 0 to 1 on the chosen domain
    cs -= cs.min()
    cs /= cs.max()
    cdf = scipy.interpolate.interp1d(x, cs, kind='cubic', assume_sorted=True)
    return cdf

def flatten(xs):
    'Flatten a list of lists by one level.'
    y = []
    for x in xs:
        y += x
    return y

def chunks(tot, n):
    '''
    Split a number into `n` chunks, as equally as possible.
    Args:
        tot: the number that will be split into chunks
        n: the number of chunks to split into
    Example:
        >>> chunks(53, n=8)
        [7, 6, 7, 6, 7, 7, 6, 7]
        >>> sum(chunks(53, n=8)) == 53
        True
    Returns:
        chunks (list): the chunks
    '''
    delta = tot / n
    # round evenly spaced boundaries to integers, then take pairwise differences
    intrange = [ int(round(i * delta)) for i in range(n+1) ]
    pairs = zip(intrange, intrange[1:])
    return list(map(lambda x: x[1]-x[0], pairs))

def _sample(args):
    '''
    Worker function to perform lookup-sampling.
    Args:
        args (tuple): contains the lookup table, the CDF
            range, and the number of samples to look up.
    Returns:
        samples (list): the samples
    '''
    lookup, rnge, n = args
    samps = []
    for _ in range(n):
        # inverse transform sampling: draw u ~ Uniform(0, 1), scan for the
        # first tabulated CDF value >= u, and map it back to its domain point
        x = random()
        for i, r in enumerate(rnge):
            if x <= r:
                break
        samps.append( lookup[ rnge[ i ] ] )
    return samps

class DistSampler:
    '''
    Approximate distribution sampling via CDF inversion using lookups.
    Example:
        import numpy as np
        import scipy.stats
        import matplotlib.pyplot as plt
        kappa = 1.25
        vmcdf = lambda x: scipy.stats.vonmises.cdf(x, kappa=kappa)
        vm = DistSampler(vmcdf, xmin=-np.pi, xmax=np.pi)
        plt.hist(vm.sample(k=1e4), bins=16)
        plt.show()
    '''
    def __init__(self, cdf, xmin, xmax, resolution=100):
        '''
        Args:
            cdf (callable): vectorized callable CDF function
            xmin (float): left boundary of the approximation domain
            xmax (float): right boundary of the approximation domain
            resolution (int): accuracy of the approximation (this
                has a big impact on sampling speed)
        '''
        domain = np.linspace(xmin, xmax, resolution)
        # tabulate the CDF on a grid and map each CDF value back to its
        # domain point, so sampling reduces to a table lookup
        self.rnge = cdf(domain)
        self.lookup = { r: x for r, x in zip(self.rnge, domain) }
        self.procs = 2 * cpu_count()
        self.pool = Pool(self.procs)

    def sample(self, k=1):
        '''
        Draw samples from the distribution. Uses multiprocessing for speed.
        Args:
            k (int): Number of samples to draw. Use this instead of calling
                this method with `k=1` repeatedly!
        Returns:
            samples (list): the approximate samples.
        '''
        k = int(k)
        chnks = chunks(k, n=2*self.procs)
        args = [ (self.lookup, self.rnge, n) for n in chnks ]
        res = self.pool.map_async(_sample, args)
        samps = flatten(res.get())
        return samps

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    #
    # Eg: Von Mises Distribution
    #
    print('Von Mises Distribution')
    mu = 0
    kappa = 1.25
    vmpdf = lambda x: np.exp(kappa * np.cos(x - mu)) / (2 * np.pi * np.i0(kappa))
    vmcdf = pdf2cdf(vmpdf, xmin=-np.pi, xmax=np.pi)
    vm = DistSampler(vmcdf, xmin=-np.pi, xmax=np.pi)
    plt.hist(vm.sample(k=1e5), bins=16)
    plt.show()
    # %timeit vm.sample(k=1e4)
    # 83.4 ms ± 17.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    # => about 8us per sample
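
    # Optional sanity check, a minimal sketch: the samples are circular data,
    # so their circular mean should land near mu = 0.
    vm_samps = np.array(vm.sample(k=1e5))
    circ_mean = np.angle(np.mean(np.exp(1j * vm_samps)))
    print('circular mean (expected near 0):', circ_mean)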

    #
    # Eg: Normal Distribution
    #
    print('Normal Distribution')
    mu = 0
    sigma = 1
    norm_pdf = lambda x: np.sqrt(2 * np.pi * sigma**2)**(-1) * np.exp(-(x - mu)**2 / (2 * sigma**2))
    norm_cdf = pdf2cdf(norm_pdf, xmin=-5, xmax=5)
    norm = DistSampler(norm_cdf, xmin=-5, xmax=5)
    plt.hist(norm.sample(k=1e5), bins=16)
    plt.show()
    # %timeit norm.sample(k=1e4)
    # 80.3 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    # => about 8us per sample
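
    # Optional accuracy check, a minimal sketch: compare the empirical CDF of
    # the lookup samples with the exact normal CDF from scipy.stats (already a
    # dependency). The maximum gap should shrink as `resolution` is increased.
    import scipy.stats
    sorted_samps = np.sort(norm.sample(k=1e5))
    emp_cdf = np.arange(1, len(sorted_samps) + 1) / len(sorted_samps)
    max_gap = np.max(np.abs(emp_cdf - scipy.stats.norm.cdf(sorted_samps)))
    print('max |empirical CDF - exact normal CDF|:', max_gap)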