
@vanAmsterdam
Created April 23, 2020 08:17
Latent confounder treatment effect estimation model in Numpyro
'''
define and run a latent variable model
'''
import numpyro
from jax import numpy as np, random
from jax.scipy.special import logsumexp
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.handlers import mask, substitute, trace, seed
from numpyro.infer.mcmc import NUTS, MCMC
from numpyro.infer.util import log_likelihood
from numpyro import diagnostics
import pandas as pd
# for parallel sampling on cpu:
numpyro.set_host_device_count(4)
def latentconfoundermodel1d(data={'tx': None, 'W1': None, 'W2': None, 'y': None},
                            control={'N': 500, 'symmetrymethod': 'ordered_W'}):
    '''
    a probabilistic model for linear regression with a latent confounder

    :param data: dictionary with np.ndarrays for all observed data to condition on,
        and None for unobserved data that should be marginalized out
    :param control: dictionary with the number of observations 'N' and the
        symmetry-breaking method 'symmetrymethod'
    '''
    # get global parameters
    ## latent factor proxies
    mu_W1 = numpyro.sample('mu_W1', dist.Normal(0, 5))
    mu_W2 = numpyro.sample('mu_W2', dist.Normal(0, 5))
    ### assume positive association for breaking symmetries
    if control['symmetrymethod'] == 'ordered_W':
        b_U_W = numpyro.param('b_U_W', np.array([0.5, 0.5]), constraint=constraints.ordered_vector)
        numpyro.sample('b_U_W_obs', dist.Normal(0, 5), obs=b_U_W)
        b_U_W1 = b_U_W[0]
        b_U_W2 = b_U_W[1]
    elif control['symmetrymethod'] == 'positive':
        b_U_W1 = numpyro.sample('b_U_W1', dist.HalfNormal(2.5))
        b_U_W2 = numpyro.sample('b_U_W2', dist.HalfNormal(2.5))
    elif control['symmetrymethod'] == 'none':
        b_U_W1 = numpyro.sample('b_U_W1', dist.Normal(0, 5))
        b_U_W2 = numpyro.sample('b_U_W2', dist.Normal(0, 5))
    else:
        raise NotImplementedError(
            f"unknown symmetrymethod: {control['symmetrymethod']}, "
            "choose from 'none', 'ordered_W' or 'positive'")
    ## treatment model
    mu_tx = numpyro.sample('mu_tx', dist.Normal(0, 5))
    if control['symmetrymethod'] == 'positive':
        b_U_tx = numpyro.sample('b_U_tx', dist.HalfNormal(2.5))
    else:
        b_U_tx = numpyro.sample('b_U_tx', dist.Normal(0, 5))
    ## outcome model
    b_tx_y = numpyro.sample('b_tx_y', dist.Normal(0, 5))
    if control['symmetrymethod'] == 'positive':
        b_U_y = numpyro.sample('b_U_y', dist.HalfNormal(2.5))
    else:
        b_U_y = numpyro.sample('b_U_y', dist.Normal(0, 5))
    mu_y = numpyro.sample('mu_y', dist.Normal(0, 5))
    s_y = numpyro.sample('s_y', dist.HalfCauchy(2.5))
    # data plate
    with numpyro.plate('obs', control['N']):
        Uhat = numpyro.sample('Uhat', dist.Normal(0, 1))
        # U -> W model (logistic regression)
        invlogit_W1 = Uhat * b_U_W1 - mu_W1
        invlogit_W2 = Uhat * b_U_W2 - mu_W2
        numpyro.sample('W1', dist.Bernoulli(logits=invlogit_W1), obs=data['W1'])
        numpyro.sample('W2', dist.Bernoulli(logits=invlogit_W2), obs=data['W2'])
        # U -> tx model (logistic regression)
        invlogit_tx = Uhat * b_U_tx - mu_tx
        tx = numpyro.sample('tx', dist.Bernoulli(logits=invlogit_tx), obs=data['tx'])
        # outcome model for the linear predictor
        mu_y_hat = Uhat * b_U_y + b_tx_y * tx - mu_y
        # sample outcome
        return numpyro.sample('y', dist.Normal(mu_y_hat, s_y), obs=data['y'])
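# Minimal sanity-check sketch (illustrative only; `_check_trace` and the PRNG key
# below are ad hoc names introduced here, not part of the model): run the model
# forward once with the default arguments, so every site is drawn from its prior,
# and print the shape of each sample site.
_check_trace = trace(seed(latentconfoundermodel1d, random.PRNGKey(0))).get_trace()
print({name: site['value'].shape for name, site in _check_trace.items()
       if site['type'] == 'sample'})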
## sample data
nsim = 500
prm_vals = dict(
    mu_W1=0.25,
    mu_W2=-0.25,
    b_U_W1=0.5,
    b_U_W2=1.25,
    mu_tx=-0.25,
    b_U_tx=0.75,
    b_tx_y=1.0,
    b_U_y=1.0,
    mu_y=0.0,
    s_y=0.1,
)
prm_vals['b_U_W'] = np.array([prm_vals['b_U_W1'], prm_vals['b_U_W2']])
def sim_from_model(rng_key, model, prm_vals, nsim=500):
    control = dict(N=nsim, symmetrymethod='none')
    # substitute the true parameter values into the model
    model = substitute(model, prm_vals)
    # run the model forward once
    tr = trace(seed(model, rng_key)).get_trace(control=control)
    # collect the sampled values of all sites into a dictionary
    data = {k: v['value'] for k, v in tr.items()}
    return data
sim_key = random.PRNGKey(1223)
num_samples = 3000
num_warmup = 1500
num_chains = 4
simdata = sim_from_model(sim_key, latentconfoundermodel1d, prm_vals, nsim)
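# Minimal illustrative check of the simulated data (`tx_sim`, `y_sim` and
# `naive_effect` are ad hoc names introduced here): treatment prevalence and the
# naive difference in mean outcome between treatment arms. Since tx and y share
# the latent confounder U, this naive estimate will generally be biased away
# from the true b_tx_y of 1.0.
tx_sim, y_sim = simdata['tx'], simdata['y']
naive_effect = ((y_sim * tx_sim).sum() / tx_sim.sum()
                - (y_sim * (1 - tx_sim)).sum() / (1 - tx_sim).sum())
print(f"simulated P(tx=1): {tx_sim.mean():.2f}, naive effect: {naive_effect:.2f}")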
## mcmc helper function
def run_mcmc(key, model, control):
    kernel = NUTS(model, target_accept_prob=0.99)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                num_chains=num_chains, progress_bar=True)
    mcmc.run(key, simdata, control, extra_fields=('potential_energy',))
    return mcmc
mcmc_key = random.PRNGKey(1224)
controls = {
    'none': {'N': nsim, 'symmetrymethod': 'none'},
    'ordered_W': {'N': nsim, 'symmetrymethod': 'ordered_W'},
    'positive': {'N': nsim, 'symmetrymethod': 'positive'},
}
# create holders
mcmcs = {} # mcmc objects
sums = {} # summaries
lds = {} # marginal log densities
divs = {} # divergent samples
for symmetrymethod, control in controls.items():
    mcmc_key, mcmc_key_ = random.split(mcmc_key)
    mcmc = run_mcmc(mcmc_key_, latentconfoundermodel1d, control)
    smps = mcmc.get_samples(group_by_chain=True)
    allvars = list(smps.keys())
    prmvars = [k for k in allvars if k not in ['Uhat']]
    sum_ = pd.DataFrame(diagnostics.summary(smps)).T
    pe = mcmc.get_extra_fields()['potential_energy']
    ld = np.mean(-pe)
    divergences = mcmc.get_extra_fields()['diverging']
    print(f"method: {symmetrymethod}")
    print(f"num divergences: {divergences.sum()}")
    print(f"expected log density: {ld:.2f}")
    print(sum_.loc[prmvars,])
    mcmcs[symmetrymethod] = mcmc
    sums[symmetrymethod] = sum_.loc[prmvars,]
    lds[symmetrymethod] = ld
    divs[symmetrymethod] = divergences
# print(pd.concat(sums))
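# Minimal illustrative comparison (ad hoc, not part of the summaries above): the
# posterior mean and 5%-95% interval of the treatment effect b_tx_y per
# symmetry-breaking method, next to the true value used in the simulation.
for method, mcmc_ in mcmcs.items():
    b = mcmc_.get_samples()['b_tx_y']
    print(f"{method}: b_tx_y mean {b.mean():.2f}, "
          f"5%-95% [{np.quantile(b, 0.05):.2f}, {np.quantile(b, 0.95):.2f}], "
          f"true {prm_vals['b_tx_y']:.2f}")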