Last active
December 28, 2020 13:14
-
-
Save fehiepsi/b4a5a80b245600b99467a0264be05fd5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from collections import namedtuple\n", | |
"import copy\n", | |
"\n", | |
"from jax import device_put, lax, random\n", | |
"import jax.numpy as jnp\n", | |
"\n", | |
"import numpyro\n", | |
"import numpyro.distributions as dist\n", | |
"from numpyro.handlers import substitute, trace, seed\n", | |
"from numpyro.infer import MCMC, NUTS, log_likelihood\n", | |
"from numpyro.infer.mcmc import MCMCKernel\n", | |
"from numpyro.util import identity" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"HMC_ECS_State = namedtuple(\"HMC_ECS_State\", \"uz, hmc_state, accept_prob, rng_key\")\n", | |
"# Attach the field descriptions to the class itself: a bare string literal placed after\n", | |
"# the assignment is evaluated and discarded, so help(HMC_ECS_State) would never show it.\n", | |
"HMC_ECS_State.__doc__ = \"\"\"\n", | |
" - **uz** - a dict of current subsample indices and the current latent values\n", | |
" - **hmc_state** - current hmc_state\n", | |
" - **accept_prob** - acceptance probability of the proposal subsample indices\n", | |
" - **rng_key** - random key to generate new subsample indices\n", | |
"\"\"\"\n", | |
"\n", | |
"def _wrap_model(model):\n", | |
"    \"\"\"Wrap ``model`` so that site values can be injected through a keyword argument.\n", | |
"\n", | |
"    The returned callable takes the same arguments as ``model`` plus an optional\n", | |
"    ``_subsample_sites`` keyword (a dict, default ``{}``). That keyword is removed\n", | |
"    from ``kwargs`` before ``model`` is called, and its entries are applied with\n", | |
"    the ``substitute`` handler while ``model`` runs.\n", | |
"\n", | |
"    :param model: callable to wrap.\n", | |
"    :return: the wrapping callable; it returns whatever ``model`` returns.\n", | |
"    \"\"\"\n", | |
"    def fn(*args, **kwargs):\n", | |
"        # pop() so the wrapped model never sees the extra keyword\n", | |
"        subsample_values = kwargs.pop(\"_subsample_sites\", {})\n", | |
"        with substitute(data=subsample_values):\n", | |
"            # propagate the model's return value instead of silently dropping it\n", | |
"            return model(*args, **kwargs)\n", | |
"\n", | |
"    return fn\n", | |
"\n", | |
"\n", | |
"class HMC_ECS(MCMCKernel):\n", | |
"    \"\"\"MCMC kernel that alternates a subsample-index update with an inner-kernel update.\n", | |
"\n", | |
"    Every call to ``sample`` proposes new subsample indices for each subsampled\n", | |
"    plate, accepts or rejects the proposal as a whole, and then advances the latent\n", | |
"    values with ``inner_kernel`` while the retained indices are substituted into\n", | |
"    the model.\n", | |
"\n", | |
"    :param inner_kernel: kernel providing ``model``, ``init``, ``sample`` and\n", | |
"        ``postprocess_fn``. It is shallow-copied, so the object passed in keeps\n", | |
"        its original model.\n", | |
"    \"\"\"\n", | |
"    sample_field = \"uz\"\n", | |
"\n", | |
"    def __init__(self, inner_kernel):\n", | |
"        self.inner_kernel = copy.copy(inner_kernel)\n", | |
"        self.inner_kernel._model = _wrap_model(inner_kernel.model)\n", | |
"        # Filled in by ``init``: maps plate name -> (size, subsample_size).\n", | |
"        self._plate_sizes = None\n", | |
"\n", | |
"    @property\n", | |
"    def model(self):\n", | |
"        \"\"\"The wrapped model, which accepts the ``_subsample_sites`` keyword.\"\"\"\n", | |
"        return self.inner_kernel._model\n", | |
"\n", | |
"    def postprocess_fn(self, args, kwargs):\n", | |
"        \"\"\"Return a function that removes the subsample indices from ``uz`` and\n", | |
"        applies the inner kernel's postprocessing to the remaining values.\"\"\"\n", | |
"        def fn(uz):\n", | |
"            z = {k: v for k, v in uz.items() if k not in self._plate_sizes}\n", | |
"            return self.inner_kernel.postprocess_fn(args, kwargs)(z)\n", | |
"\n", | |
"        return fn\n", | |
"\n", | |
"    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):\n", | |
"        \"\"\"Find the subsampled plates and build the initial state.\n", | |
"\n", | |
"        The model is traced once; every plate whose size exceeds its subsample\n", | |
"        size contributes its traced indices to the initial ``u``, which is handed\n", | |
"        to the inner kernel's ``init`` through ``_subsample_sites``.\n", | |
"\n", | |
"        :return: an ``HMC_ECS_State`` whose ``accept_prob`` is 1.\n", | |
"        \"\"\"\n", | |
"        model_kwargs = {} if model_kwargs is None else model_kwargs.copy()\n", | |
"        rng_key, key_u, key_z = random.split(rng_key, 3)\n", | |
"        prototype_trace = trace(seed(self.model, key_u)).get_trace(*model_args, **model_kwargs)\n", | |
"        u = {name: site[\"value\"] for name, site in prototype_trace.items()\n", | |
"             if site[\"type\"] == \"plate\" and site[\"args\"][0] > site[\"args\"][1]}\n", | |
"        self._plate_sizes = {name: prototype_trace[name][\"args\"] for name in u}\n", | |
"        model_kwargs[\"_subsample_sites\"] = u\n", | |
"        hmc_state = self.inner_kernel.init(key_z, num_warmup, init_params,\n", | |
"                                           model_args, model_kwargs)\n", | |
"        uz = {**u, **hmc_state.z}\n", | |
"        return device_put(HMC_ECS_State(uz, hmc_state, 1., rng_key))\n", | |
"\n", | |
"    def _subsample_loglik(self, sample, u, model_args, model_kwargs):\n", | |
"        \"\"\"Sum every log-likelihood term of ``sample`` with indices ``u`` substituted.\"\"\"\n", | |
"        loglik = log_likelihood(self.model, sample, *model_args, batch_ndims=0,\n", | |
"                                **model_kwargs, _subsample_sites=u)\n", | |
"        return sum(v.sum() for v in loglik.values())\n", | |
"\n", | |
"    def sample(self, state, model_args, model_kwargs):\n", | |
"        \"\"\"Run one transition: update the subsample indices, then the latent values.\n", | |
"\n", | |
"        :param state: the current ``HMC_ECS_State``.\n", | |
"        :return: the next ``HMC_ECS_State``; its ``accept_prob`` is the acceptance\n", | |
"            probability of the proposed subsample indices.\n", | |
"        \"\"\"\n", | |
"        model_kwargs = {} if model_kwargs is None else model_kwargs.copy()\n", | |
"        rng_key, key_u = random.split(state.rng_key)\n", | |
"        u = {k: v for k, v in state.uz.items() if k in self._plate_sizes}\n", | |
"        # Propose fresh indices, drawn without replacement, for every subsampled plate.\n", | |
"        u_new = {}\n", | |
"        for name, (size, subsample_size) in self._plate_sizes.items():\n", | |
"            key_u, subkey = random.split(key_u)\n", | |
"            u_new[name] = random.choice(subkey, size, (subsample_size,), replace=False)\n", | |
"        sample = self.postprocess_fn(model_args, model_kwargs)(state.hmc_state.z)\n", | |
"        u_loglik = self._subsample_loglik(sample, u, model_args, model_kwargs)\n", | |
"        u_new_loglik = self._subsample_loglik(sample, u_new, model_args, model_kwargs)\n", | |
"        # jnp.minimum replaces jnp.clip(..., a_max=1.0): the a_max keyword is deprecated\n", | |
"        # in recent JAX releases, while minimum is available in every release.\n", | |
"        accept_prob = jnp.minimum(jnp.exp(u_new_loglik - u_loglik), 1.0)\n", | |
"        # Use the (pred, true_fun, false_fun, operand) form of lax.cond; the older\n", | |
"        # (pred, true_operand, true_fun, false_operand, false_fun) form was removed.\n", | |
"        u = lax.cond(random.bernoulli(key_u, accept_prob),\n", | |
"                     lambda pair: pair[0], lambda pair: pair[1], (u_new, u))\n", | |
"        model_kwargs[\"_subsample_sites\"] = u\n", | |
"        hmc_state = self.inner_kernel.sample(state.hmc_state, model_args, model_kwargs)\n", | |
"        uz = {**u, **hmc_state.z}\n", | |
"        return HMC_ECS_State(uz, hmc_state, accept_prob, rng_key)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"sample: 100%|██████████| 1000/1000 [00:11<00:00, 90.54it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
" mean std median 5.0% 95.0% n_eff r_hat\n", | |
" x 1.01 0.01 1.01 0.99 1.02 180.80 1.02\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"def model(data):\n", | |
"    \"\"\"Normal model with latent location ``x``, scored on a subsample of ``data``.\n", | |
"\n", | |
"    ``x`` has a ``Normal(0, 1)`` prior. Inside plate ``N`` (size ``data.shape[0]``,\n", | |
"    subsample size 100) the selected slice of ``data`` is observed under\n", | |
"    ``Normal(x, 1)``.\n", | |
"    \"\"\"\n", | |
"    loc = numpyro.sample(\"x\", dist.Normal(0, 1))\n", | |
"    num_obs = data.shape[0]\n", | |
"    with numpyro.plate(\"N\", num_obs, subsample_size=100):\n", | |
"        minibatch = numpyro.subsample(data, event_dim=0)\n", | |
"        numpyro.sample(\"obs\", dist.Normal(loc, 1), obs=minibatch)\n", | |
"\n", | |
"kernel = HMC_ECS(NUTS(model))\n", | |
"# Pass the warmup/sample counts by keyword: newer NumPyro releases declare them\n", | |
"# keyword-only, and the keyword form is also valid on the older positional API.\n", | |
"mcmc = MCMC(kernel, num_warmup=500, num_samples=500)\n", | |
"# 10000 synthetic observations: standard normal draws shifted by +1\n", | |
"data = random.normal(random.PRNGKey(1), (10000,)) + 1\n", | |
"mcmc.run(random.PRNGKey(0), data, extra_fields=(\"accept_prob\",))\n", | |
"# there is a bug when exclude_deterministic=True, which will be fixed upstream\n", | |
"mcmc.print_summary(exclude_deterministic=False)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment