Last active
January 29, 2022 16:23
-
-
Save aflaxman/c2c1b343d1550bda20fc813407e1c81d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np, pandas as pd, matplotlib.pyplot as plt | |
from collections import namedtuple | |
from jax import random, lax | |
from jax.flatten_util import ravel_pytree | |
import jax.numpy as jnp | |
import numpyro | |
import numpyro.distributions as dist | |
from numpyro.infer import MCMC, util, init_to_sample | |
from numpyro.infer.mcmc import MCMCKernel | |
# Kernel state carried between ABC transitions:
#   z       -- dict of sample-site values for the current draw
#   rng_key -- JAX PRNG key to use for the next transition
ABCState = namedtuple("ABCState", "z rng_key")
class ABC(MCMCKernel):
    """Rejection-style Approximate Bayesian Computation (ABC) MCMC kernel.

    Each transition draws proposals from ``model`` with
    ``numpyro.infer.util.Predictive`` until the distance
    ``summary_statistic(data, proposal)`` is at most ``threshold`` or
    ``max_attempts_per_sample`` proposals have been tried. If no proposal is
    accepted, the previous state's values are kept for every sample site.

    Args:
        model: numpyro model function that proposals are drawn from.
        data: observed value, passed as the first argument of
            ``summary_statistic``.
        threshold: largest distance at which a proposal is accepted.
        summary_statistic: callable ``(data, proposal) -> distance`` where
            ``proposal`` is the dict of site values returned by ``Predictive``.
        max_attempts_per_sample: cap on proposals drawn per ``sample`` call.
    """

    def __init__(self, model, data, threshold, summary_statistic,
                 max_attempts_per_sample):
        self._model = model
        self._data = data
        # num_samples=1: every call yields a single draw per sample site.
        self._predictive = util.Predictive(self._model, num_samples=1)
        self._threshold = jnp.array(threshold)
        self._summary_statistic = summary_statistic
        self._max_attempts_per_sample = max_attempts_per_sample

    @property
    def sample_field(self):
        """Name of the ``ABCState`` field that holds the sampled values."""
        return "z"

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Return an initial ``ABCState`` holding one draw from ``model``.

        ``num_warmup`` and ``init_params`` are accepted for the
        ``MCMCKernel`` interface but are not used. The initial draw is not
        checked against ``threshold``.
        """
        assert rng_key.ndim == 1, "only non-vectorized, for now"
        # Split so the key that produced the initial draw is not also the key
        # carried into `sample` (a JAX PRNG key should be consumed only once).
        rng_key, init_key = random.split(rng_key)
        proposal = self._predictive(init_key, *model_args, **model_kwargs)
        return ABCState(proposal, rng_key)

    def sample(self, state, model_args, model_kwargs):
        """Run one ABC transition from ``state`` and return the next ``ABCState``."""
        def while_condition_func(val):
            distance, _rng_key, _proposal, n = val
            # Keep drawing while the last proposal is too far away and the
            # attempt budget is not yet exhausted.
            return jnp.logical_and(distance > self._threshold,
                                   n < self._max_attempts_per_sample)

        def while_body_func(val):
            _distance, rng_key, _proposal, n = val
            rng_key, sample_key = random.split(rng_key)
            proposal = self._predictive(sample_key, *model_args, **model_kwargs)
            # FIXME: need to resample the values of the observed vars here
            distance = self._summary_statistic(self._data, proposal)
            return (distance, rng_key, proposal, n + 1)

        distance, rng_key, proposal, _n = lax.while_loop(
            while_condition_func,
            while_body_func,
            (jnp.inf,        # distance: inf forces a first draw (finite threshold)
             state.rng_key,  # rng_key
             state.z,        # proposal: same pytree structure the body returns
             0),             # number of proposals drawn so far
        )
        accepted = distance <= self._threshold
        # Fall back to the previous values for *every* site rather than one
        # hard-coded name, so a rejected transition never mixes sites taken
        # from different draws.
        z = {site: jnp.where(accepted, value, state.z[site])
             for site, value in proposal.items()}
        return ABCState(z, rng_key)
def my_model():
    """Toy prior-only model: four ``theta`` values, each drawn from Uniform(-10, 10).

    The draws live inside a plate named ``'I'`` of size 4; the function
    returns nothing.
    """
    prior = dist.Uniform(-10, 10)
    with numpyro.plate('I', 4):
        numpyro.sample('theta', prior)
def sum_exceeds_threshold(threshold, proposal):
    """Return a 0/inf "distance" based on the total of ``proposal['theta']``.

    Args:
        threshold: value the summed ``theta`` must strictly exceed.
        proposal: mapping with a ``'theta'`` array; all of its elements are
            summed.

    Returns:
        0 when ``sum(theta) > threshold``, otherwise ``inf``.
    """
    total = jnp.sum(proposal['theta'])
    exceeds = total > threshold
    return jnp.where(exceeds, 0, jnp.inf)
def my_run(model, rng_seed=12345, num_samples=100):
    """Run the ABC kernel on ``model`` and plot the trace of the summed ``theta``.

    The summary statistic ``sum_exceeds_threshold`` returns 0 when the summed
    ``theta`` is above ``sum_lower_bound`` and inf otherwise. With
    ``threshold=1``, a proposal is therefore accepted exactly when its sum
    exceeds the bound.

    Args:
        model: numpyro model with a ``theta`` sample site.
        rng_seed: integer seed passed to ``jax.random.PRNGKey``.
        num_samples: number of MCMC samples to draw.

    Returns:
        The dict of samples from ``MCMC.get_samples()``.
    """
    rng_key = random.PRNGKey(rng_seed)
    sum_lower_bound = jnp.array(-1)
    kernel = ABC(model,
                 data=sum_lower_bound, threshold=1,
                 summary_statistic=sum_exceeds_threshold,
                 max_attempts_per_sample=1_000)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=num_samples, thinning=1)
    mcmc.run(rng_key)
    posterior_samples = mcmc.get_samples()
    # Index 0 on axis 1 drops the size-1 axis; summing the remaining axis
    # gives one scalar per MCMC sample.
    plt.plot(posterior_samples['theta'][:, 0, :].sum(axis=1), label='trace')
    plt.ylabel('sum of theta')
    plt.axhline(sum_lower_bound, linestyle='dashed', color='k',
                label='lower bound')
    # Without this call the 'trace' / 'lower bound' labels are never shown.
    plt.legend()
    return posterior_samples
# Script entry: run the ABC demo on the toy model.
my_run(model=my_model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment