# Latent confounder treatment effect estimation model in NumPyro
# Source: GitHub gist vanAmsterdam/57c65208bb997a1a47cf207302e4812c
# Created: April 23, 2020 08:17
'''
Define and run a latent variable model.
'''
import numpyro | |
from jax import numpy as np, random | |
from jax.scipy.special import logsumexp | |
from numpyro import distributions as dist | |
from numpyro.distributions import constraints | |
from numpyro.handlers import mask, substitute, trace, seed | |
from numpyro.infer.mcmc import NUTS, MCMC | |
from numpyro.infer.util import log_likelihood | |
from numpyro import diagnostics | |
import pandas as pd | |
# for parallel sampling on cpu: ask for 4 host devices up front
# (4 matches the num_chains value set further down in this script)
numpyro.set_host_device_count(4)
def latentconfoundermodel1d(data=None, control=None):
    '''
    a probabilistic model for linear regression with a latent confounder

    A standard-normal latent variable ``Uhat`` (one per row of the data plate)
    drives two binary proxies ``W1`` and ``W2``, a binary treatment ``tx`` and a
    continuous outcome ``y``; ``y`` additionally depends on ``tx``.

    :param data dict: a dictionary with np.ndarrays for all observed data that is used to condition on, and None for unobserved / to marginalize out data.
        Keys read are 'tx', 'W1', 'W2' and 'y'; a missing key is treated like None. Other keys are ignored.
    :param control dict: settings, merged over the defaults ``{'N': 500, 'symmetrymethod': 'ordered_W'}``.
        'N' is the size of the data plate, 'symmetrymethod' is one of 'none', 'ordered_W' or 'positive'.
    :return: the value of the 'y' sample site
    :raises NotImplementedError: if 'symmetrymethod' is not one of the supported options
    '''
    # build the defaults per call instead of sharing mutable default arguments
    if data is None:
        data = {}
    settings = {'N': 500, 'symmetrymethod': 'ordered_W'}
    if control is not None:
        settings.update(control)
    symmetrymethod = settings['symmetrymethod']

    # fail fast, before any site is registered
    if symmetrymethod not in ('none', 'ordered_W', 'positive'):
        raise NotImplementedError(f"wrong symmetrymethod: {symmetrymethod}, choose from 'none', 'ordered_W' or 'positive'")

    def loading_prior():
        # prior for every coefficient on the latent factor:
        # restricted to positive values only for the 'positive' method
        if symmetrymethod == 'positive':
            return dist.HalfNormal(2.5)
        return dist.Normal(0, 5)

    # get global parameters
    # NOTE: the order of the sample statements below is kept fixed so that a
    # given seed keeps producing the same draws.
    ## latent factor proxies
    mu_W1 = numpyro.sample('mu_W1', dist.Normal(0, 5))
    mu_W2 = numpyro.sample('mu_W2', dist.Normal(0, 5))
    ### assume positive association for breaking symmetries
    if symmetrymethod == 'ordered_W':
        # NOTE(review): this declares the loadings as an ordered-vector *param*
        # and attaches a Normal(0, 5) density through an observed site; whether
        # MCMC treats a param site as a latent variable depends on the NumPyro
        # version -- confirm against the installed release.
        b_U_W = numpyro.param('b_U_W', np.array([0.5, 0.5]), constraint=constraints.ordered_vector)
        numpyro.sample('b_U_W_obs', dist.Normal(0, 5), obs=b_U_W)
        b_U_W1 = b_U_W[0]
        b_U_W2 = b_U_W[1]
    else:
        b_U_W1 = numpyro.sample('b_U_W1', loading_prior())
        b_U_W2 = numpyro.sample('b_U_W2', loading_prior())

    ## treatment model
    mu_tx = numpyro.sample('mu_tx', dist.Normal(0, 5))
    b_U_tx = numpyro.sample('b_U_tx', loading_prior())

    ## outcome model
    b_tx_y = numpyro.sample('b_tx_y', dist.Normal(0, 5))
    b_U_y = numpyro.sample('b_U_y', loading_prior())
    mu_y = numpyro.sample('mu_y', dist.Normal(0, 5))
    s_y = numpyro.sample('s_y', dist.HalfCauchy(2.5))

    # data plate
    with numpyro.plate('obs', settings['N']):
        Uhat = numpyro.sample('Uhat', dist.Normal(0, 1))

        # U -> W model (logistic regression); the intercepts enter with a minus sign
        logit_W1 = Uhat * b_U_W1 - mu_W1
        logit_W2 = Uhat * b_U_W2 - mu_W2
        numpyro.sample('W1', dist.Bernoulli(logits=logit_W1), obs=data.get('W1'))
        numpyro.sample('W2', dist.Bernoulli(logits=logit_W2), obs=data.get('W2'))

        # U -> tx model (logistic regression)
        logit_tx = Uhat * b_U_tx - mu_tx
        tx = numpyro.sample('tx', dist.Bernoulli(logits=logit_tx), obs=data.get('tx'))

        # outcome model for the linear predictor
        mu_y_hat = Uhat * b_U_y + b_tx_y * tx - mu_y

        # sample outcome
        return numpyro.sample('y', dist.Normal(mu_y_hat, s_y), obs=data.get('y'))
## sample data
# number of simulated observations
nsim = 500
# reference parameter values, keyed by the model's site names
prm_vals = {
    'mu_W1': 0.25,
    'mu_W2': -0.25,
    'b_U_W1': 0.5,
    'b_U_W2': 1.25,
    'mu_tx': -0.25,
    'b_U_tx': 0.75,
    'b_tx_y': 1.0,
    'b_U_y': 1.0,
    'mu_y': 0.0,
    's_y': 0.1,
}
# the same two proxy loadings stacked into one vector under the name 'b_U_W'
prm_vals['b_U_W'] = np.array([prm_vals[name] for name in ('b_U_W1', 'b_U_W2')])
def sim_from_model(rng_key, model, prm_vals, nsim=500):
    '''
    run `model` forward once with the sites named in `prm_vals` fixed to the given values

    :param rng_key: random key used to seed the model
    :param model: the model function; it is called with a `control` keyword argument only
    :param prm_vals dict: values to substitute, keyed by site name
    :param nsim int: passed to the model as control['N']
    :return: a dictionary with the value of every site recorded in the trace
    '''
    # the forward run always uses symmetrymethod 'none'
    settings = dict(N=nsim, symmetrymethod='none')
    # substitute the parameters, then seed the remaining sample sites
    seeded_model = seed(substitute(model, prm_vals), rng_key)
    # record one execution and keep only each site's value
    sites = trace(seeded_model).get_trace(control=settings)
    return {name: site['value'] for name, site in sites.items()}
# sampler settings (read as module-level globals by run_mcmc)
num_warmup = 1500
num_samples = 3000
num_chains = 4
# simulate one data set from the model at the reference parameter values
sim_key = random.PRNGKey(1223)
simdata = sim_from_model(sim_key, latentconfoundermodel1d, prm_vals, nsim=nsim)
## mcmc helper function | |
def run_mcmc(key, model, control, data=None):
    '''
    run NUTS on `model` and return the fitted MCMC object

    The numbers of warmup draws, samples and chains are read from the
    module-level `num_warmup`, `num_samples` and `num_chains`.

    :param key: random key passed to `MCMC.run`
    :param model: the model function to sample from; called as model(data, control)
    :param control dict: control settings forwarded to the model
    :param data dict: data forwarded to the model; defaults to the module-level `simdata`
    :return: the MCMC object after running, with 'potential_energy' collected as an extra field
    '''
    if data is None:
        data = simdata
    # use the model that was passed in rather than a hard-coded one
    kernel = NUTS(model, target_accept_prob=0.99)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains, progress_bar=True)
    mcmc.run(key, data, control, extra_fields=('potential_energy',))
    return mcmc
mcmc_key = random.PRNGKey(1224)
# one control dict per symmetry-breaking method, all with the simulated sample size
controls = {
    method: {'N': nsim, 'symmetrymethod': method}
    for method in ('none', 'ordered_W', 'positive')
}

# create holders, each keyed by symmetry method
mcmcs = {}  # mcmc objects
sums = {}   # summaries (parameters only, without 'Uhat')
lds = {}    # mean of the negated potential energy over the draws
divs = {}   # divergence indicators

for symmetrymethod, control in controls.items():
    # split off a fresh key for this run and carry the other half forward
    mcmc_key, mcmc_key_ = random.split(mcmc_key)
    mcmc = run_mcmc(mcmc_key_, latentconfoundermodel1d, control)

    # summarise the draws, keeping the chain dimension
    smps = mcmc.get_samples(group_by_chain=True)
    allvars = list(smps.keys())
    prmvars = [k for k in allvars if k != 'Uhat']
    sum_ = pd.DataFrame(diagnostics.summary(smps)).T
    prm_summary = sum_.loc[prmvars,]

    # sampler diagnostics
    extra_fields = mcmc.get_extra_fields()
    pe = extra_fields['potential_energy']
    ld = np.mean(-pe)
    divergences = extra_fields['diverging']

    print(f"method: {symmetrymethod}")
    print(f"num divergences: {divergences.sum()}")
    print(f"expected log density: {ld:.2f}")
    print(prm_summary)

    mcmcs[symmetrymethod] = mcmc
    sums[symmetrymethod] = prm_summary
    lds[symmetrymethod] = ld
    divs[symmetrymethod] = divergences
# print(pd.concat(sums))
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")