Created
September 6, 2017 21:21
-
-
Save michaelchughes/bfa97a3599ba5de0be96f005d96244e3 to your computer and use it in GitHub Desktop.
Computing marginal log likelihood of data for unknown mean, fixed variance normal model.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Computing marginal log likelihood of data for unknown mean, fixed variance model.

Running Demo
------------
$ python marg_lik_of_normal.py

This will compare numerical and exact methods for calculating the marginal likelihood.
By visual inspection, the user can see that the exact method matches the numerical one.

Demo uses prior with mean=0, var=1.
Demo uses likelihood variance = 0.337.
Five single-point datasets (N=1) use fixed values of "x".
Random datasets "x" of size N=3 and N=30 are sampled from a standard normal.

Expected Output
---------------
N = 1 | x = -1.00 | numer -1.43812e+00 | exact -1.43812e+00
N = 1 | x = -0.50 | numer -1.15765e+00 | exact -1.15765e+00
N = 1 | x = 0.00 | numer -1.06415e+00 | exact -1.06415e+00
N = 1 | x = 0.50 | numer -1.15765e+00 | exact -1.15765e+00
N = 1 | x = 1.00 | numer -1.43812e+00 | exact -1.43812e+00
N = 3 | x = 1.76 0.40 0.98 | numer -4.15559e+00 | exact -4.15559e+00
N = 3 | x = 2.24 1.87 -0.98 | numer -1.19548e+01 | exact -1.19548e+01
N = 3 | x = 0.95 -0.15 -0.10 | numer -3.44567e+00 | exact -3.44567e+00
N = 3 | x = 0.41 0.14 1.45 | numer -3.89610e+00 | exact -3.89610e+00
N = 30 | x = 0.76 0.12 0.44... | numer -6.59719e+01 | exact -6.59719e+01
N = 30 | x = -1.71 1.95 -0.51... | numer -4.39386e+01 | exact -4.39386e+01
N = 30 | x = 1.14 -1.23 0.40... | numer -5.70260e+01 | exact -5.70260e+01
N = 30 | x = -1.27 0.97 -1.17... | numer -5.82784e+01 | exact -5.82784e+01

Using the functions
-------------------
See calc_log_pdf_x__exact to compute the exact marginal log likelihood.

Notation
--------
Sample the mean:
    u ~ Normal( mean=m_prior, var=v_prior)
Sample each data index i independently:
    x_i ~ Normal( mean=u, var=v_lik)

We represent normal parameters via a dictionary with keys:
* 'v' for variance
* 'm' for mean (omit if unknown)
'''
import numpy as np | |
import scipy.stats | |
def calc_suff_stats(x_N):
    ''' Compute sufficient statistics for 1D dataset

    Args
    ----
    x_N : 1D array-like of observed values (ndarray, list, or tuple)

    Returns
    -------
    SS : dict with keys
        * nx : number of points in the dataset
        * sumx : sum of all points in dataset
        * sumxx : sum of squares of points in dataset
        * barx : mean of the points in dataset (0.0 if dataset is empty)
    '''
    # Coerce so that plain Python sequences work as well as ndarrays
    # (the .size attribute used below exists only on arrays).
    x_N = np.asarray(x_N)
    SS = dict(
        nx=x_N.size,
        sumx=np.sum(x_N),
        sumxx=np.sum(np.square(x_N)),
        )
    # The tiny offset only matters when nx == 0: it turns 0/0 into 0.0.
    # For nx >= 1 it is far below float resolution, so barx is the plain mean.
    SS['barx'] = SS['sumx'] / (1e-100 + SS['nx'])
    return SS
def calc_param_dict_for_post_of_mu_given_data(x_N, P_lik, P_prior):
    ''' Calculate posterior parameters for p(mu | data)

    Args
    ----
    x_N : 1D array of observed values
    P_lik : dict with key 'v', the fixed variance of the likelihood
    P_prior : dict with keys 'm' and 'v', mean and variance of prior on mu

    Returns
    -------
    param_dict : dict with keys
        * m : mean, float
        * v : variance, float
    '''
    SS_data = calc_suff_stats(x_N)
    # Promote parameters to float so that integer-valued inputs such as
    # dict(m=0, v=1) are never floor-divided by Python 2's "/" operator.
    v_lik = 1.0 * P_lik['v']
    m_prior = 1.0 * P_prior['m']
    v_prior = 1.0 * P_prior['v']
    # Posterior variance: equivalent to 1 / (nx / v_lik + 1 / v_prior)
    v = v_lik * v_prior / (SS_data['nx'] * v_prior + v_lik)
    # Posterior mean: variance times the sum of prior and data natural params
    m = v * (m_prior / v_prior + SS_data['sumx'] / v_lik)
    return dict(m=m, v=v)
def calc_log_pdf_x__pred_post(x_test, x_N, P_lik, P_prior):
    ''' Evaluate posterior predictive log density log p(x_test | x_N)

    The predictive is a normal centered at the posterior mean of mu, with
    variance equal to the posterior variance plus the likelihood variance.

    Args
    ----
    x_test : value(s) at which to evaluate the predictive density
    x_N : 1D array of observed values to condition on
    P_lik : dict with key 'v', the fixed variance of the likelihood
    P_prior : dict with keys 'm' and 'v', mean and variance of prior on mu

    Returns
    -------
    logpdf : 1D array, size of x_test
    '''
    post = calc_param_dict_for_post_of_mu_given_data(x_N, P_lik, P_prior)
    pred_var = post['v'] + P_lik['v']
    pred_std = np.sqrt(pred_var)
    return scipy.stats.norm.logpdf(x_test, loc=post['m'], scale=pred_std)
def calc_log_pdf_x__exact(
        x_N, P_lik, P_prior):
    r''' Compute marginal likelihood via exact closed-form formula

    Evaluates log \int_u p(x,u) exactly

    Args
    ----
    x_N : 1D array of observed values
    P_lik : dict with key 'v', the fixed variance of the likelihood
    P_prior : dict with keys 'm' and 'v', mean and variance of prior on mu

    Returns
    -------
    logpdf : scalar float
    '''
    # NOTE: the docstring is a raw string because "\i" is not a valid
    # escape sequence and triggers a warning in regular string literals.
    SS = calc_suff_stats(x_N)
    P_post = calc_param_dict_for_post_of_mu_given_data(x_N, P_lik, P_prior)
    denom = P_lik['v'] * P_prior['v']
    log_pdf = (
        # Normalization constants of the nx likelihood terms and the prior,
        # plus the constant recovered by integrating out u (posterior var).
        - 0.5 * SS['nx'] * np.log(2 * np.pi)
        - 0.5 * np.log(P_prior['v'])
        + 0.5 * np.log(P_post['v'])
        - 0.5 * SS['nx'] * np.log(P_lik['v'])
        # Quadratic terms in the data, split into the scatter around the
        # sample mean and the contribution of the sample mean itself.
        # Together these two lines equal -0.5 * sumxx / v_lik.
        - 0.5 * (SS['sumxx'] - SS['nx'] * np.square(SS['barx'])) / P_lik['v']
        - 0.5 / denom * SS['nx'] * P_prior['v'] * np.square(SS['barx'])
        # Quadratic term in the prior mean: equals -0.5 * m_prior^2 / v_prior
        - 0.5 / denom * P_lik['v'] * np.square(P_prior['m'])
        # Term left over from completing the square in u
        + 0.5 * np.square(P_post['m']) / P_post['v']
        )
    return log_pdf
def calc_log_pdf_x__numeric(
        x_N, P_lik, P_prior,
        u_grid=np.linspace(-6, 6, 250000),
        ):
    r''' Compute marginal likelihood via numerical integration

    Evaluates log \int_u p(x,u) at fixed grid of u values, using the
    trapezoid rule. Any probability mass lying outside the range covered
    by u_grid is ignored, so the grid must span the region where the joint
    density p(x,u) is non-negligible.

    Args
    ----
    x_N : 1D array-like of observed values
    P_lik : dict with key 'v', the fixed variance of the likelihood
    P_prior : dict with keys 'm' and 'v', mean and variance of prior on mu
    u_grid : 1D array of increasing candidate values for the mean u

    Returns
    -------
    logpdf : scalar float
    '''
    x_N = np.asarray(x_N)
    u_grid = np.asarray(u_grid)
    # log p(u) at each grid point
    log_prior_G = scipy.stats.norm.logpdf(
        u_grid, P_prior['m'], np.sqrt(P_prior['v']))
    # log p(x_N | u) at each grid point: sum over data of per-point log pdf
    log_lik_G = np.sum(
        scipy.stats.norm.logpdf(
            x_N[:, np.newaxis],
            u_grid[np.newaxis, :], np.sqrt(P_lik['v'])),
        axis=0)
    assert log_prior_G.size == log_lik_G.size
    log_joint_G = log_prior_G + log_lik_G

    # Rescale by the largest value before exponentiating. Without this,
    # exp() underflows to exactly 0 for larger datasets (log joint below
    # roughly -745) and the result would be log(0) = -inf.
    max_log_joint = np.max(log_joint_G)
    if not np.isfinite(max_log_joint):
        # Nothing finite to rescale by (all -inf, or NaN input)
        return max_log_joint

    # np.trapz was renamed to np.trapezoid in NumPy 2.0; support both.
    trapezoid = getattr(np, 'trapezoid', None)
    if trapezoid is None:
        trapezoid = np.trapz
    scaled_integral = trapezoid(np.exp(log_joint_G - max_log_joint), u_grid)
    return max_log_joint + np.log(scaled_integral)
def pprint_log_pdf__numeric_vs_exact(x_N, P_lik, P_prior):
    ''' Pretty print comparison of numerical and exact method on one line

    Prints the dataset size, the first (up to) three data values, and the
    marginal log likelihood computed by both methods.

    Args
    ----
    x_N : 1D array of observed values
    P_lik : dict with key 'v', the fixed variance of the likelihood
    P_prior : dict with keys 'm' and 'v', mean and variance of prior on mu

    Returns
    -------
    Nothing
    '''
    # Show at most three values; mark longer datasets as truncated
    x_preview = ' '.join(['% 5.2f' % x for x in x_N[:3]])
    if x_N.size > 3:
        x_preview += '...'
    log_pdf_numeric = calc_log_pdf_x__numeric(x_N, P_lik, P_prior)
    log_pdf_exact = calc_log_pdf_x__exact(x_N, P_lik, P_prior)
    # print(...) with one pre-formatted string behaves identically under
    # the Python 2 print statement and the Python 3 print function.
    print(" N = %2d | x = %-21s | numer % .5e | exact % .5e" % (
        x_N.size,
        x_preview,
        log_pdf_numeric,
        log_pdf_exact,
        ))
if __name__ == '__main__':
    # Seeded generator so the random demo datasets are reproducible
    prng = np.random.RandomState(0)

    # Model setup: prior on the unknown mean has mean 0 and variance 1;
    # the Gaussian likelihood has a FIXED, known variance.
    P_prior = dict(m=0, v=1)
    P_lik = dict(v=0.337)

    # Hand-picked datasets holding exactly one observation each
    for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
        pprint_log_pdf__numeric_vs_exact(np.asarray([x]), P_lik, P_prior)

    # Four random datasets of size N = 3, followed by four of size N = 30
    for n_points in (3, 30):
        for _ in range(4):
            pprint_log_pdf__numeric_vs_exact(
                prng.randn(n_points), P_lik, P_prior)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment