Last active
June 26, 2024 18:25
-
-
Save vankesteren/1854cf92b6ec26a3bb9628048bb2b9b6 to your computer and use it in GitHub Desktop.
Comparing lintsampler to basic uniform importance sampling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Comparing lintsampler to basic uniform importance sampling | |
from scipy.stats import norm, uniform | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from lintsampler import LintSampler | |
# Number of proposals, resampled draws, and lintsampler draws used throughout.
NSAMPLES = 1_000_000
# GMM example
def gmm_pdf(x, mu=(-3.0, 0.5, 2.5), sig=(1.0, 0.25, 0.75), w=(0.4, 0.25, 0.35)):
    """Evaluate a 1-D Gaussian-mixture density elementwise.

    Parameters
    ----------
    x : array_like
        Evaluation points of any shape (a scalar is allowed).
    mu, sig, w : sequence of float
        Component means, standard deviations and mixture weights, all the same
        length. The defaults are the three-component mixture used in this
        example; their weights sum to 1, so the default result is a normalized
        density.

    Returns
    -------
    numpy.ndarray or numpy.float64
        Mixture density with the same shape as ``x``.
    """
    mu = np.asarray(mu, dtype=float)
    sig = np.asarray(sig, dtype=float)
    w = np.asarray(w, dtype=float)
    # Append a trailing component axis so every point is evaluated against every
    # component in a single vectorized call, then sum the weighted components.
    x = np.asarray(x, dtype=float)[..., None]
    return np.sum(w * norm.pdf(x, mu, sig), axis=-1)
# importance sampling (sampling-importance-resampling with a uniform proposal)
rng = np.random.default_rng(42)
# scipy's uniform(loc, scale) covers [loc, loc + scale], i.e. [-12, 12] here.
propdist = uniform(-12, 24)
proposals = propdist.rvs(NSAMPLES, random_state=rng)
# The importance weight is target / proposal. The proposal density is constant
# on its support, so it cancels when the weights are normalized below.
weights = gmm_pdf(proposals)
# Resample with the seeded generator rather than the global np.random state so
# the whole run is reproducible from the seed above. Generator.choice samples
# with replacement by default, so proposals can appear more than once.
importance_samples = rng.choice(proposals, NSAMPLES, p=weights / weights.sum())
# plot: histogram of the importance-resampled draws against the exact GMM density
bins = np.linspace(-12, 12, 200)
plt.hist(importance_samples, bins=bins, density=True, label="Samples", fc="goldenrod")
plt.plot(bins, gmm_pdf(bins), label="True PDF", c='teal')
# Both artists carry labels; without legend() the labels are never displayed.
plt.legend()
plt.show()
# Compare with lintsampler: the same GMM target, evaluated on a fixed grid of
# 33 points spanning [-12, 12] and sampled with the same seed and sample count.
# NOTE(review): lintsampler presumably draws from a (piecewise-)linear
# interpolant of the pdf on this grid; confirm against the lintsampler docs.
fixedgrid = np.linspace(-12, 12, 33)
rng = np.random.default_rng(42)
lintsampler_samples = LintSampler(
    fixedgrid,
    pdf=gmm_pdf,
    vectorizedpdf=True,
    seed=rng,
).sample(N=NSAMPLES)
# Histogram of the lintsampler draws against the exact GMM density.
plt.hist(lintsampler_samples, bins=bins, density=True, label="Samples", fc="goldenrod")
plt.plot(bins, gmm_pdf(bins), label="True PDF", c='teal')
# Both artists carry labels; without legend() the labels are never displayed.
plt.legend()
plt.show()
# compare log-likelihood: total log-density of each sample set under the target.
# Bare expressions are only echoed in an interactive session, so print them to
# make the comparison visible when this file is run as a script too.
# NOTE(review): a higher total log-density is not by itself evidence of more
# accurate samples (draws piled up near the modes also score higher); a
# histogram distance or two-sample test would be a fairer comparison.
print("lintsampler log-likelihood:", np.log(gmm_pdf(lintsampler_samples)).sum())
print("importance sampling log-likelihood:", np.log(gmm_pdf(importance_samples)).sum())
# Doughnut example
def circles_pdf(x):
    """Unnormalized 2-D density concentrated on two circles ("doughnuts").

    Each circle contributes a Gaussian ridge ``exp(-0.5 * d**2 / w**2)``, where
    ``d`` is the distance from the point to that circle. The circles have
    centers (-2, -2) and (2, 2), radius 1, and ridge width ``w = 0.4``.

    Parameters
    ----------
    x : array_like, shape (..., 2)
        Evaluation points; the last axis holds the two coordinates. A single
        point of shape (2,) is accepted, as are integer arrays.

    Returns
    -------
    numpy.ndarray
        Density values with shape ``x.shape[:-1]``. They are not normalized
        to integrate to 1.
    """
    x = np.asarray(x)
    c1 = np.array([-2.0, -2.0])
    r1 = 1.0
    c2 = np.array([2.0, 2.0])
    r2 = 1.0
    w = 0.4
    # The nearest point on a circle (center c, radius r) to x is
    # c + r * (x - c) / |x - c|, which lies at distance | |x - c| - r |.
    # At the center, every circle point is at distance r, which the same formula
    # gives. Using the closed form avoids the division by zero and any
    # dtype-inheriting scratch buffer, so integer inputs are not truncated.
    d1 = np.abs(np.linalg.norm(x - c1, axis=-1) - r1)
    d2 = np.abs(np.linalg.norm(x - c2, axis=-1) - r2)
    return np.exp(-0.5 * d1**2 / w**2) + np.exp(-0.5 * d2**2 / w**2)
# importance sampling (sampling-importance-resampling) on the square [-4, 4]^2
rng = np.random.default_rng(42)
# One uniform (x, y) proposal per row; uniform(-4, 8) covers [-4, 4].
proposals = uniform(-4, 8).rvs(size=(NSAMPLES, 2), random_state=rng)
# The uniform proposal density is constant, so the target density alone serves
# as the (unnormalized) importance weight.
imp_weights = circles_pdf(proposals)
# Resample row indices with the seeded generator rather than the global
# np.random state so the run is reproducible from the seed above.
idx = rng.choice(NSAMPLES, NSAMPLES, p=imp_weights / imp_weights.sum())
importance_samples = proposals[idx]
# visual: 128x128-bin 2-D histogram of the importance-resampled points
plt.hist2d(
    importance_samples[:, 0],
    importance_samples[:, 1],
    bins=128,
    range=[[-4, 4], [-4, 4]],
    cmap='inferno',
)
plt.show()
# Compare with lintsampler: the same target on a regular grid of
# N_grid x N_grid cells over [-4, 4]^2, with the same seed and sample count.
N_grid = 128
edges = np.linspace(-4, 4, N_grid + 1)
rng = np.random.default_rng(42)
lintsampler_samples = LintSampler(
    (edges, edges),
    pdf=circles_pdf,
    seed=rng,
    vectorizedpdf=True,
).sample(N=NSAMPLES)
# visual: 128x128-bin 2-D histogram of the lintsampler draws
plt.hist2d(
    lintsampler_samples[:, 0],
    lintsampler_samples[:, 1],
    bins=128,
    range=[[-4, 4], [-4, 4]],
    cmap='inferno',
)
plt.show()
# compare log-likelihood: total log-density of each sample set under the target.
# Bare expressions are only echoed in an interactive session, so print them to
# make the comparison visible when this file is run as a script too.
# circles_pdf is unnormalized, so both totals are shifted by the same constant
# (NSAMPLES * log Z); only their difference is meaningful.
# NOTE(review): a higher total log-density is not by itself evidence of more
# accurate samples; a histogram distance or two-sample test would be fairer.
print("lintsampler log-likelihood:", np.log(circles_pdf(lintsampler_samples)).sum())
print("importance sampling log-likelihood:", np.log(circles_pdf(importance_samples)).sum())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment