Created
May 1, 2020 13:13
-
-
Save vanAmsterdam/c87e4552892c6fdcc306282ab4948ad5 to your computer and use it in GitHub Desktop.
define and run a latent variable model, everything Gaussian, with posterior predictive on new data. For a description of the model, see the Twitter thread: https://twitter.com/WvanAmsterdam/status/1251214875394740226?s=20
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
define and run a latent variable model, everything gaussian, with posterior predictive on new data
for a description of the model see twitter thread: https://twitter.com/WvanAmsterdam/status/1251214875394740226?s=20 | |
DAG: | |
W1 <- U -> W2 # latent confounder with 2 proxies | |
U -> tx | |
U -> y | |
tx -> y | |
''' | |
import numpyro | |
from jax import numpy as np, random | |
from jax.scipy.special import logsumexp | |
from numpyro import distributions as dist | |
from numpyro.distributions import constraints | |
from numpyro.handlers import mask, substitute, trace, seed, condition | |
from numpyro.infer.mcmc import NUTS, MCMC | |
from numpyro.infer.util import log_likelihood | |
from numpyro import diagnostics | |
import arviz as az | |
import pandas as pd | |
import re | |
numpyro.set_host_device_count(4) | |
def latentconfoundermodel1d(
        control=None,
        data=None,
        priors=None, prm_vals=None):
    '''
    A probabilistic model for linear regression with a latent confounder.

    DAG: W1 <- U -> W2 (latent confounder with two proxies), U -> tx, U -> y, tx -> y.

    :param control dict: control arguments: 'N' (plate size) and 'sample_priors'
        (True: sample parameters from `priors`; False: use fixed `prm_vals`).
        Defaults to {'N': 500, 'sample_priors': True}.
    :param data dict: np.ndarrays for all observed data that is used to condition on,
        and None for unobserved / to-marginalize-out data. Defaults to all-None.
    :param priors dict: priors keyed by parameter name (required when sampling priors).
    :param prm_vals dict: fixed parameter values (required when not sampling priors).
    :return: the sampled (or observed) outcome site 'y'.
    '''
    # BUG FIX: the original used mutable dict default arguments, which are
    # created once and shared (and potentially mutated) across calls; use
    # None sentinels and build fresh defaults per call instead.
    if control is None:
        control = {'N': 500, 'sample_priors': True}
    if data is None:
        data = {'tx': None, 'W1': None, 'W2': None, 'y': None}
    # get global parameters: sampled from priors, or fixed values
    if control['sample_priors']:
        prms = {prm_name: numpyro.sample(prm_name, prior) for prm_name, prior in priors.items()}
    else:
        prms = prm_vals
    # data plate over the N observations
    with numpyro.plate('obs', control['N']):
        # latent confounder, standard normal
        Uhat = numpyro.sample('Uhat', dist.Normal(0, 1))
        # U -> W measurement model: two noisy proxies of U
        muhat_W1 = Uhat * prms['b_U_W1']
        muhat_W2 = Uhat * prms['b_U_W2']
        numpyro.sample('W1', dist.Normal(muhat_W1, prms['s_W1']), obs=data['W1'])
        numpyro.sample('W2', dist.Normal(muhat_W2, prms['s_W2']), obs=data['W2'])
        # U -> tx treatment model
        muhat_tx = Uhat * prms['b_U_tx']
        tx = numpyro.sample('tx', dist.Normal(muhat_tx, prms['s_tx']), obs=data['tx'])
        # outcome model: linear predictor from the latent confounder and treatment
        muhat_y = Uhat * prms['b_U_y'] + tx * prms['b_tx_y']
        # sample (or condition on) the outcome
        return numpyro.sample('y', dist.Normal(muhat_y, prms['s_y']), obs=data['y'])
## sample data
# ground-truth parameter values used to simulate datasets from the model
prm_vals = {
    'b_U_W1': 0.5,
    'b_U_W2': 0.5,
    's_W1': 0.1,
    's_W2': 0.2,
    'b_U_tx': 0.75,
    's_tx': 0.2,
    'b_tx_y': 1.0,
    'b_U_y': 0.75,
    's_y': 0.2,
}
# vague Normal(0, 5) prior on every parameter by default
prm_priors = {name: dist.Normal(0, 5) for name in prm_vals}
# fix some parameters including the latent confounder to positive values for
# identification: give those parameters HalfNormal priors instead
pos_prms = ['b_U_W1', 'b_U_W2', 'b_U_tx', 'b_U_y']
prm_priors.update({name: dist.HalfNormal(2.5) for name in pos_prms})
def sim_from_model(rng_key, model, prm_vals, nsim=500):
    '''Simulate a dataset of size nsim by running the model forward with fixed parameters.

    Returns a dict mapping each sample-site name to its simulated values.
    '''
    ctrl = dict(N=nsim, sample_priors=False)
    # run the seeded model forward, recording every sample site in a trace
    exec_trace = trace(seed(model, rng_key)).get_trace(control=ctrl, prm_vals=prm_vals)
    # keep only the drawn values, keyed by site name
    return {site: node['value'] for site, node in exec_trace.items()}
nsim = 1000
# two independent PRNG keys: one for the training set, one for the test set
sim_keys = random.split(random.PRNGKey(1224), 2)
# training data simulated from the true parameter values
simdata = sim_from_model(sim_keys[0], latentconfoundermodel1d, prm_vals, nsim)
# create test data: an independent draw with the outcome removed
simdata2 = sim_from_model(sim_keys[1], latentconfoundermodel1d, prm_vals, nsim)
testdata = dict(simdata2)
testdata['y'] = None  # test data should not contain y
# MCMC configuration: sample parameters from their priors during fitting
control = {'N': nsim, 'sample_priors': True}
num_samples = 3000
num_warmup = 1500
num_chains = 4
## do mcmc
def run_mcmc(key, model, *args, **kwargs):
    '''Fit the model with NUTS and return the finished MCMC object.

    NOTE: num_warmup / num_samples / num_chains are read from module-level globals.
    Extra args/kwargs are forwarded to the model via mcmc.run.
    '''
    sampler = MCMC(
        NUTS(model, target_accept_prob=0.95),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=True,
    )
    sampler.run(key, *args, **kwargs)
    return sampler
# one key for fitting, one for posterior prediction later on
mcmc_keys = random.split(random.PRNGKey(1225), 2)
# condition on the simulated training data and sample the posterior
mcmc = run_mcmc(mcmc_keys[0], latentconfoundermodel1d, control=control, data=simdata, priors=prm_priors)
smps = mcmc.get_samples(group_by_chain=True)
# summarize the global parameters only, skipping the per-observation latent Uhat
prmvars = [name for name in smps if name not in ['Uhat']]
sum_ = pd.DataFrame(diagnostics.summary(smps)).T
print(sum_.loc[prmvars,])
# sanity check: count divergent transitions during sampling
divergences = mcmc.get_extra_fields()['diverging']
print(f"num divergences: {divergences.sum()}")
## now do posterior prediction on a new dataset
# create model that fixes all fixed arguments in the posterior mode | |
def make_jittable(model, control, data, priors):
    '''Return a wrapper around `model` with control, data and priors bound.

    Any remaining positional/keyword arguments (e.g. prm_vals) are passed
    through to the wrapped model.

    :param model: callable taking (control, data, priors, *args, **kwargs)
    :return: a callable of (*args, **kwargs) with the first three args fixed
    '''
    def newmodel(*args, **kwargs):
        # BUG FIX: the original dropped the wrapped model's return value;
        # propagate it so the wrapper is a faithful partial application.
        return model(control, data, priors, *args, **kwargs)
    return newmodel
# posterior draws flattened across chains, used to condition the predictions
smps2 = mcmc.get_samples(group_by_chain=False)
# reuse fit-time control settings but fix parameters instead of sampling priors
pp_control = {**control, 'sample_priors': False}
# script for getting posterior-predictive samples
def get_postpred_smps(key, model, control, data, priors, postsamples, num_draws=10, num_warmup=100, num_samples=100):
    '''
    Draw posterior-predictive samples by re-running short MCMC chains with the
    model parameters fixed at successive posterior draws.

    :param key: jax PRNG key
    :param model: the numpyro model to predict from
    :param control dict: control arguments for the model ('N', 'sample_priors')
    :param data dict: observed data to condition on (unobserved sites are None)
    :param priors dict: priors, passed through to the model
    :param postsamples dict: flat posterior samples keyed by parameter name
    :param num_draws int: number of posterior draws to condition on
    :param num_warmup int: warmup iterations of each inner MCMC run
    :param num_samples int: sampling iterations of each inner MCMC run
    :return: list of dicts, one per draw, mapping site name to its last sample
    '''
    # BUG FIX: the original hard-coded latentconfoundermodel1d here instead of
    # using the `model` parameter, silently ignoring the model passed in.
    jittable_model = make_jittable(model, control, data, priors)
    # jit_model_args=True lets repeated runs with new prm_vals reuse compilation
    mcmc = MCMC(NUTS(jittable_model), num_warmup, num_samples, num_chains=1,
                jit_model_args=True, progress_bar=False)
    keys = random.split(key, num_draws)
    draws = []
    for i in range(num_draws):
        print(i, end='')
        # the i-th joint posterior draw of all parameters
        smp = {k: v[i] for k, v in postsamples.items()}
        mcmc.run(keys[i], prm_vals=smp)
        postsmps = mcmc.get_samples()
        postsmp = {k: v[-1] for k, v in postsmps.items()}  # grab last sample of each run
        draws.append(postsmp)
    return draws
# posterior-predictive draws on the held-out test data (y unobserved)
pp_smps = get_postpred_smps(
    mcmc_keys[1], latentconfoundermodel1d,
    pp_control, testdata, prm_priors, smps2,
    num_draws=100)
## convert to a dict of type: {varname: (num_posterior_draws, N)}
def list_of_dicts_to_dict_of_lists(LD):
    '''Transpose a list of dicts (all sharing the keys of the first) into a dict of lists.'''
    keys = LD[0]
    return {key: [entry[key] for entry in LD] for key in keys}
# reshape: list of per-draw dicts -> dict of lists -> dict of stacked arrays,
# each array with leading axis num_posterior_draws
pp_smps = list_of_dicts_to_dict_of_lists(pp_smps)
pp_smps = {site: np.stack(values, axis=0) for site, values in pp_smps.items()}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment