Created
January 4, 2022 05:01
-
-
Save aflaxman/4e7ae00a24a3c798a7021721d2773d9d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from jax import tree_util
from jax.flatten_util import ravel_pytree
from numpyro.infer import MCMC, util
from numpyro.infer.mcmc import MCMCKernel
# Chain state carried from one MCMC step to the next:
#   z       -- current parameter values (a pytree, e.g. a dict keyed by site name)
#   rng_key -- PRNG key to split for the next step
# The field name "z" is arbitrary ("u" would work equally well), but it must
# match the string returned by the kernel's `sample_field` property.
MetState = namedtuple("MetState", ["z", "rng_key"])
class Metropolis(MCMCKernel):
    """Random-walk Metropolis kernel with a Gaussian proposal.

    Each step perturbs every (flattened) parameter with independent
    ``Normal(current, step_size)`` noise and accepts the move with
    probability ``min(1, p(proposal) / p(current))``.  The proposal is
    symmetric, so no Hastings correction term is needed.

    Args:
        model: NumPyro model callable whose log density is evaluated with
            ``numpyro.infer.util.log_density``.
        step_size: Standard deviation of the Gaussian proposal.
    """

    def __init__(self, model, step_size=0.1):
        self._model = model
        self._step_size = step_size

    @property
    def sample_field(self):
        """Name of the ``MetState`` field that holds the sampled values."""
        return "z"

    @property
    def default_fields(self):
        """State fields reported for each sample."""
        return ("z",)

    @staticmethod
    def _as_float(leaf):
        """Return ``leaf`` as an array, cast to the default float dtype if it is not inexact."""
        leaf = jnp.asarray(leaf)
        if jnp.issubdtype(leaf.dtype, jnp.inexact):
            return leaf
        return leaf.astype(jnp.result_type(float))

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """Build the initial chain state.

        Args:
            rng_key: A single (non-vectorized) PRNG key.
            num_warmup: Unused; this kernel has no adaptation phase.
            init_params: Pytree of starting values, e.g. ``{"x": 0.0}``.
            model_args: Unused here.
            model_kwargs: Unused here.

        Returns:
            ``MetState`` holding the starting values and ``rng_key``.

        Raises:
            ValueError: If ``init_params`` is ``None``.
        """
        assert rng_key.ndim == 1, "only non-vectorized, for now"
        if init_params is None:
            raise ValueError(
                "Metropolis needs explicit init_params, e.g. init_params={'x': 0.0}"
            )
        # Proposals are always floating point, so an integer starting value
        # (e.g. {'x': 0}) would give the state a different dtype at step 0
        # than at every later step.  Cast up front so the dtype is stable.
        init_params = tree_util.tree_map(self._as_float, init_params)
        return MetState(init_params, rng_key)

    def sample(self, state, model_args, model_kwargs):
        """Run one Metropolis step.

        Args:
            state: Current ``MetState``.
            model_args: Positional arguments forwarded to the model.
            model_kwargs: Keyword arguments forwarded to the model.

        Returns:
            New ``MetState`` holding either the accepted proposal or the
            unchanged current values, plus a fresh PRNG key.
        """
        rng_key, key_proposal, key_accept = random.split(state.rng_key, 3)

        # Work on a flat vector so a single Normal draw perturbs every parameter.
        z_flat, unravel_fn = ravel_pytree(state.z)
        z_proposal = dist.Normal(z_flat, self._step_size).sample(key_proposal)

        log_pr_0, _ = util.log_density(
            self._model, model_args, model_kwargs, state.z
        )
        log_pr_1, _ = util.log_density(
            self._model, model_args, model_kwargs, unravel_fn(z_proposal)
        )

        # Accept when log(u) < log p(proposal) - log p(current).  Comparing in
        # log space avoids exponentiating a possibly huge log ratio; a NaN
        # ratio compares False, so such proposals are rejected.
        log_u = jnp.log(dist.Uniform().sample(key_accept))
        accept = log_u < log_pr_1 - log_pr_0
        z_new = jnp.where(accept, z_proposal, z_flat)
        return MetState(unravel_fn(z_new), rng_key)
def model():
    """Toy target: a single latent site ``x`` with a standard normal distribution."""
    numpyro.sample("x", dist.Normal(0, 1))
# Demo: draw 200 samples from the toy model with the Metropolis kernel.
rng_key = random.PRNGKey(12345)
kernel = Metropolis(model, step_size=1)
mcmc = MCMC(kernel, num_warmup=0, num_samples=200, thinning=1)
# Start from a float, not the integer 0: the kernel's proposals are floating
# point, so a float start keeps the state dtype the same from the first step.
mcmc.run(rng_key, init_params={"x": 0.0})
posterior_samples = mcmc.get_samples()
mcmc.print_summary()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment