Skip to content

Instantly share code, notes, and snippets.

Last active July 13, 2019 20:28
Show Gist options
  • Save johnmeade/1ce44243afd5ab4a595da7c4aa129e19 to your computer and use it in GitHub Desktop.
Save johnmeade/1ce44243afd5ab4a595da7c4aa129e19 to your computer and use it in GitHub Desktop.
Fast sampling from arbitrary probability densities in Python
"""
Tools for sampling from arbitrary probability densities.

Requirements:
    pip install scipy numpy

John Meade 2019
MIT license
"""
import numpy as np
import scipy.interpolate
from random import random
from multiprocessing import cpu_count, Pool
def pdf2cdf(pdf, xmin, xmax, resolution=100):
    """
    Create an approximate CDF function via numerical integration of a PDF.

    The PDF is evaluated on an evenly spaced grid of `resolution` points,
    accumulated with a running sum, shifted and rescaled so that it spans
    [0, 1], and then interpolated with a cubic spline.

    Args:
        pdf (callable): vectorized callable PDF function
        xmin (float): left boundary of the approximation domain
        xmax (float): right boundary of the approximation domain
        resolution (int): accuracy of the approximation (number of grid
            points the PDF is evaluated on)

    Returns:
        cdf (callable): CDF, an interpolator defined on [xmin, xmax]

    Raises:
        ValueError: if the accumulated PDF has no positive spread on the
            grid (e.g. an all-zero or NaN-valued PDF), in which case it
            cannot be normalized.
    """
    x = np.linspace(xmin, xmax, resolution)
    # Coerce to float: the in-place normalization below cannot be applied to
    # an integer array (e.g. a PDF built from integer counts).
    y = np.asarray(pdf(x), dtype=float)
    cs = np.cumsum(y)
    # normalize due to boundaries
    cs -= cs.min()
    total = cs.max()
    # `not total > 0` also catches NaN, which would otherwise silently
    # produce a CDF that returns NaN everywhere.
    if not total > 0:
        raise ValueError(
            'cannot normalize: the PDF has no probability mass on '
            '[xmin, xmax]'
        )
    cs /= total
    cdf = scipy.interpolate.interp1d(x, cs, kind='cubic', assume_sorted=True)
    return cdf
def flatten(xs):
    """
    One-level array flatten operation.

    Args:
        xs: an iterable whose items are themselves iterable

    Returns:
        list: the elements of every item of `xs`, in order, in a single new
        list. Only one level is removed; nested items are kept as they are.
    """
    return [element for group in xs for element in group]
def chunks(tot, n):
    """
    Split a number into `n` chunks, as equally as possible.

    The cut points are the rounded multiples of `tot / n`; each chunk is the
    distance between two consecutive cut points.

    Args:
        tot: the number that will be split into chunks
        n: the number of chunks to split into

    Example:
        >>> chunks(53, n=8)
        [7, 6, 7, 6, 7, 7, 6, 7]
        >>> sum(chunks(53, n=8)) == 53
        True

    Returns:
        chunks (list): the chunks
    """
    step = tot / n
    edges = [int(round(k * step)) for k in range(n + 1)]
    return [upper - lower for lower, upper in zip(edges, edges[1:])]
def _sample(args):
    """
    Worker function to perform lookup-sampling.

    For each sample a uniform random number is drawn and mapped to the value
    stored in the lookup table under the first CDF value that is greater than
    or equal to the draw.

    Args:
        args (tuple): contains the lookup table (mapping CDF value -> sample
            value), the CDF range (sequence of CDF values, scanned in order),
            and the number of samples to look up.

    Returns:
        samples (list): the samples; exactly `n` of them when the CDF range
        is non-empty, otherwise an empty list.
    """
    lookup, rnge, n = args
    # len() rather than truthiness: `rnge` may be a numpy array.
    if len(rnge) == 0:
        return []
    last = rnge[-1]
    samps = []
    for _ in range(n):
        x = random()
        for r in rnge:
            if x <= r:
                samps.append(lookup[r])
                # Stop at the first match: without this every later CDF
                # value would also add a sample for the same draw.
                break
        else:
            # The draw is above every tabulated CDF value (the table may top
            # out slightly below 1); use the last entry so that one sample
            # is still produced for this draw.
            samps.append(lookup[last])
    return samps
class DistSampler:
    """
    Approximate distribution sampling via CDF inversion using lookups.

    The sampler owns a pool of worker processes. Release it with `close()`
    when done, or use the sampler as a context manager.

    Example:
        import numpy as np
        import scipy.stats
        import matplotlib.pyplot as plt
        kappa = 1.25
        vmcdf = lambda x: scipy.stats.vonmises.cdf(x, kappa=kappa)
        vm = DistSampler(vmcdf, xmin=-np.pi, xmax=np.pi)
        plt.hist(vm.sample(k=1e4), bins=16)
        vm.close()
    """

    def __init__(self, cdf, xmin, xmax, resolution=100):
        """
        Args:
            cdf (callable): vectorized callable CDF function
            xmin (float): left boundary of the approximation domain
            xmax (float): right boundary of the approximation domain
            resolution (int): accuracy of the approximation (this
                has a big impact on sampling speed)
        """
        domain = np.linspace(xmin, xmax, resolution)
        # CDF evaluated on the grid; workers scan these values in order.
        self.rnge = cdf(domain)
        # Inverse table: CDF value -> domain point. Equal CDF values collapse
        # onto one key, keeping the last domain point that produced them.
        self.lookup = {r: x for r, x in zip(self.rnge, domain)}
        self.procs = 2 * cpu_count()
        self.pool = Pool(self.procs)

    def __enter__(self):
        """Return the sampler itself so it can be used in a `with` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Shut the worker pool down on leaving a `with` block."""
        self.close()

    def close(self):
        """
        Shut down the worker pool and wait for its processes to exit.

        Safe to call more than once. After closing, `sample` raises
        RuntimeError.
        """
        pool = getattr(self, 'pool', None)
        if pool is not None:
            pool.close()
            pool.join()
            self.pool = None

    def sample(self, k=1):
        """
        Draw samples from the distribution. Uses multiprocessing for speed.

        Args:
            k (int): Number of samples to draw. Use this instead of calling
                this method with `k=1` repeatedly!

        Returns:
            samples (list): the approximate samples.

        Raises:
            RuntimeError: if the sampler has already been closed.
        """
        if self.pool is None:
            raise RuntimeError('DistSampler is closed')
        k = int(k)
        chnks = chunks(k, n=2 * self.procs)
        # Empty chunks produce no samples, so skip them rather than shipping
        # the lookup table to a worker for nothing.
        args = [(self.lookup, self.rnge, n) for n in chnks if n > 0]
        samps = flatten(self.pool.map(_sample, args))
        return samps
if __name__ == '__main__':
    # Demo: build lookup samplers for two densities and histogram the draws.
    import matplotlib.pyplot as plt

    # Eg: Von Mises Distribution
    print('Von Mises Distribution')
    mu = 0
    kappa = 1.25

    def vmpdf(x):
        """Von Mises density with location `mu` and concentration `kappa`."""
        return np.exp(kappa * np.cos(x - mu)) / (2 * np.pi * np.i0(kappa))

    vmcdf = pdf2cdf(vmpdf, xmin=-np.pi, xmax=np.pi)
    vm = DistSampler(vmcdf, xmin=-np.pi, xmax=np.pi)
    vm_draws = vm.sample(k=1e5)
    plt.hist(vm_draws, bins=16)
    # Timing recorded by the author (IPython):
    # %timeit vm.sample(k=1e4)
    # 83.4 ms ± 17.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    # => about 8us per sample

    # Eg: Normal Distribution
    print('Normal Distribution')
    mu = 0
    sigma = 1

    def norm_pdf(x):
        """Normal density with mean `mu` and standard deviation `sigma`."""
        return np.sqrt(2 * np.pi * sigma**2)**(-1) * np.exp(-(x - mu)**2 / (2 * sigma**2))

    norm_cdf = pdf2cdf(norm_pdf, xmin=-5, xmax=5)
    norm = DistSampler(norm_cdf, xmin=-5, xmax=5)
    norm_draws = norm.sample(k=1e5)
    plt.hist(norm_draws, bins=16)
    # Timing recorded by the author (IPython):
    # %timeit norm.sample(k=1e4)
    # 80.3 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    # => about 8us per sample
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment