Created
November 1, 2021 04:04
-
-
Save michaelchughes/278b4e70274ce778f1231b74701e1dce to your computer and use it in GitHub Desktop.
Demonstration of ELBO computation using the Monte Carlo method
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' Variational inference (VI) for a Poisson-Normal model

Model
-----
Latent variable z is drawn from a Normal prior: z ~ Normal(mean=40, stddev=10)
Data y is drawn iid from a Poisson likelihood: y_n ~ Poisson(z)

Approx Posterior
----------------
The approximate posterior q(z) is Normal with a free (unknown) mean and stddev.
The ELBO for a given q is estimated by Monte Carlo sampling from q.
'''
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import numpy as np | |
import scipy.stats | |
import jax | |
import jax.numpy as jnp | |
import jax.scipy.stats as jstats | |
def calc_ELBO(q, prior, data, random_state=None, n_mc_samples=100):
    ''' Estimate the ELBO objective via Monte Carlo samples.

    Draws S = n_mc_samples samples z_s ~ q = Normal(q['mean'], q['stddev'])
    and averages the single-sample estimates

        log p(y | z_s) + log p(z_s) - log q(z_s),

    where p(y | z) = prod_n Poisson(y_n | z) and p(z) is the Normal prior.

    Args
    ----
    q : dict with keys 'mean' and 'stddev'
        Parameters of the Normal approximate posterior.
    prior : dict with keys 'mean' and 'stddev'
        Parameters of the Normal prior on z.
    data : dict with key 'y_N'
        Observed counts; anything accepted by np.asarray (array or list).
    random_state : np.random.RandomState, int, or None
        Source of randomness for the Monte Carlo samples. An int seeds a
        fresh RandomState; None seeds a fresh RandomState from OS entropy.
    n_mc_samples : int
        Number of Monte Carlo samples S.

    Returns
    -------
    elbo : scalar jax array
        Monte Carlo ELBO estimate divided by the number of observations N,
        i.e. a per-observation ELBO.

    Raises
    ------
    ValueError
        If data['y_N'] is empty, since dividing by N is then undefined.

    Notes
    -----
    The Normal q can produce samples z <= 0, which lie outside the support
    of a Poisson rate. When that happens (e.g. q['stddev'] large relative to
    q['mean']), the Poisson log-pmf is not meaningful and the estimate will
    generally not be finite.
    '''
    y_N = np.asarray(data['y_N'])
    N = y_N.size
    if N == 0:
        raise ValueError("data['y_N'] must contain at least one observation")
    if random_state is None or isinstance(random_state, (int, np.integer)):
        # Previously a None default crashed on .randn; seed a fresh generator.
        random_state = np.random.RandomState(random_state)
    S = n_mc_samples

    # Sample z = mean + stddev * eps with eps ~ Normal(0, 1), i.e. z ~ q.
    z_S = random_state.randn(S) * q['stddev'] + q['mean']
    log_prior_pdf_S = jstats.norm.logpdf(z_S, prior['mean'], prior['stddev'])
    log_q_pdf_S = jstats.norm.logpdf(z_S, q['mean'], q['stddev'])
    # Broadcast (N, 1) counts against (1, S) rates -> (N, S) log-likelihoods.
    log_lik_pdf_NS = jstats.poisson.logpmf(
        y_N.reshape((N, 1)), z_S.reshape((1, S)))
    # Sum over observations to get the full-data log-likelihood per sample.
    elbo_S = jnp.sum(log_lik_pdf_NS, axis=0) + log_prior_pdf_S - log_q_pdf_S
    return jnp.mean(elbo_S) / N
if __name__ == '__main__':
    n_mc_samples = 1000
    random_state = np.random.RandomState(0)

    # Simulate N iid counts from the Poisson likelihood at a known rate.
    z_true = 50.0
    N = 100
    y_N = scipy.stats.poisson(z_true).rvs(size=N, random_state=random_state)
    data = {
        'y_N': y_N,
    }
    prior = {
        'mean': 40.0,
        'stddev': 10.0,
    }

    # Try q where the mean is varied from far below to far above true value.
    # A tiny stddev makes q nearly a point mass at its mean.
    m_list = list()
    elbo_list = list()
    for delta in [-10, -5, 0, 5, 10]:
        q = {
            'mean': z_true + delta,
            'stddev': 0.001,
        }
        elbo = calc_ELBO(q, prior, data, random_state, n_mc_samples)
        # Convert the 0-d jax array to a plain float for plotting.
        elbo_list.append(float(elbo))
        m_list.append(q['mean'])
    plt.plot(m_list, elbo_list, label='ELBO')
    # Dashed vertical marker at the true rate, spanning the ELBO range.
    plt.plot(z_true * np.ones(2), [np.min(elbo_list), np.max(elbo_list)],
             '--', label='true z')
    plt.xlabel('mean of q')
    plt.ylabel('ELBO (estimated with %d samples)' % n_mc_samples)
    plt.legend()

    plt.figure()
    # Try q where the STDDEV is varied from far below to far above ideal value.
    # NOTE: at the largest stddev (30) some samples of z are likely <= 0,
    # outside the Poisson rate support, so that estimate may be nan and
    # leave a gap in the curve.
    n_reps = 5
    for rr in range(n_reps):
        s_list = list()
        elbo_list = list()
        for stddev in [0.01, 0.03, 0.1, 0.3, 1, 3.0, 10.0, 30.]:
            q = {
                'mean': z_true,
                'stddev': stddev,
            }
            elbo = calc_ELBO(q, prior, data, random_state, n_mc_samples)
            elbo_list.append(float(elbo))
            s_list.append(q['stddev'])
        plt.plot(np.log10(s_list), elbo_list, label='rep %02d' % (rr + 1))
    plt.xlabel('log stddev of q')
    plt.ylabel('ELBO (estimated with %d samples)' % n_mc_samples)
    plt.legend()
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Expected output plot