Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Created August 18, 2019 03:17
Show Gist options
  • Save fehiepsi/f0b1b4b2e987d8dde3993adceaa52063 to your computer and use it in GitHub Desktop.
Save fehiepsi/f0b1b4b2e987d8dde3993adceaa52063 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os; os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=24'\n",
"\n",
"from jax import pmap, random\n",
"from jax.config import config; config.update('jax_platform_name', 'cpu')\n",
"import jax.numpy as np\n",
"\n",
"import numpyro.distributions as dist\n",
"from numpyro.handlers import sample\n",
"from numpyro.hmc_util import initialize_model\n",
"from numpyro.mcmc import hmc\n",
"from numpyro.util import fori_collect"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def model(dim):\n",
"    \"\"\"Place a Normal prior with zero mean and unit scale on a `dim`-dimensional site 'x'.\"\"\"\n",
"    sample('x', dist.Normal(loc=np.zeros(dim), scale=1.))\n",
"\n",
"\n",
"def mcmc(rng, step_size=1, dim=5, adapt_step_size=False, num_warmup=0, num_samples=1000, num_chains=24):\n",
"    \"\"\"Run `num_chains` HMC chains on `model(dim)` in parallel, one chain per device via `pmap`.\n",
"\n",
"    :param rng: PRNG key; split into one key for initialization and one for the chains.\n",
"    :param step_size: initial integrator step size passed to the kernel.\n",
"    :param dim: dimension of the latent site in `model`.\n",
"    :param adapt_step_size: whether the kernel adapts the step size during warmup.\n",
"    :param num_warmup: number of warmup iterations.\n",
"    :param num_samples: number of iterations collected after warmup.\n",
"    :param num_chains: number of chains; `pmap` needs at least this many devices\n",
"        (the first cell forces 24 host devices through XLA_FLAGS).\n",
"    :return: the states collected by `fori_collect`, stacked along a leading chain axis.\n",
"    \"\"\"\n",
"    rng_init, rng_hmc = random.split(rng)\n",
"    init_params, potential_fn, _ = initialize_model(random.split(rng_init, num_chains), model, dim)\n",
"\n",
"    def single_chain_mcmc(rng, init_params):\n",
"        init_kernel, sample_kernel = hmc(potential_fn)\n",
"        # run_warmup=False: init_kernel does not run warmup itself; the num_warmup\n",
"        # iterations are instead part of the fori_collect loop below (presumably\n",
"        # adapted inside sample_kernel -- confirm for this numpyro version).\n",
"        hmc_state = init_kernel(init_params, num_warmup, step_size=step_size,\n",
"                                adapt_step_size=adapt_step_size,\n",
"                                run_warmup=False, adapt_mass_matrix=False, rng=rng)\n",
"        results = fori_collect(0, num_warmup + num_samples, sample_kernel, hmc_state,\n",
"                               progbar=False)\n",
"        return results\n",
"\n",
"    return pmap(single_chain_mcmc)(random.split(rng_hmc, num_chains), init_params)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 24 chains sampling a 5-dimensional Gaussian with different step sizes"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step size 0.01, max leapfrog steps: 511\n",
"Step size 0.03, max leapfrog steps: 127\n",
"Step size 0.05, max leapfrog steps: 63\n",
"Step size 0.07, max leapfrog steps: 63\n",
"Step size 0.09, max leapfrog steps: 63\n",
"Step size 0.11, max leapfrog steps: 31\n",
"Step size 0.13, max leapfrog steps: 31\n",
"Step size 0.15, max leapfrog steps: 31\n",
"Step size 0.17, max leapfrog steps: 31\n",
"Step size 0.19, max leapfrog steps: 31\n",
"Step size 0.21, max leapfrog steps: 15\n",
"Step size 0.23, max leapfrog steps: 15\n",
"Step size 0.25, max leapfrog steps: 15\n",
"Step size 0.27, max leapfrog steps: 15\n",
"Step size 0.29, max leapfrog steps: 15\n",
"Step size 0.31, max leapfrog steps: 15\n",
"Step size 0.33, max leapfrog steps: 15\n",
"Step size 0.35, max leapfrog steps: 15\n",
"Step size 0.37, max leapfrog steps: 15\n",
"Step size 0.39, max leapfrog steps: 15\n",
"Step size 0.41, max leapfrog steps: 15\n",
"Step size 0.43, max leapfrog steps: 15\n",
"Step size 0.45, max leapfrog steps: 7\n",
"Step size 0.47, max leapfrog steps: 7\n",
"Step size 0.49, max leapfrog steps: 7\n",
"Step size 0.51, max leapfrog steps: 7\n",
"Step size 0.53, max leapfrog steps: 7\n",
"Step size 0.55, max leapfrog steps: 7\n",
"Step size 0.57, max leapfrog steps: 7\n",
"Step size 0.59, max leapfrog steps: 7\n",
"Step size 0.61, max leapfrog steps: 7\n",
"Step size 0.63, max leapfrog steps: 7\n",
"Step size 0.65, max leapfrog steps: 7\n",
"Step size 0.67, max leapfrog steps: 7\n",
"Step size 0.69, max leapfrog steps: 7\n",
"Step size 0.71, max leapfrog steps: 7\n",
"Step size 0.73, max leapfrog steps: 7\n",
"Step size 0.75, max leapfrog steps: 7\n",
"Step size 0.77, max leapfrog steps: 7\n",
"Step size 0.79, max leapfrog steps: 7\n",
"Step size 0.81, max leapfrog steps: 7\n",
"Step size 0.83, max leapfrog steps: 7\n",
"Step size 0.85, max leapfrog steps: 7\n",
"Step size 0.87, max leapfrog steps: 7\n",
"Step size 0.89, max leapfrog steps: 7\n",
"Step size 0.91, max leapfrog steps: 7\n",
"Step size 0.93, max leapfrog steps: 7\n",
"Step size 0.95, max leapfrog steps: 7\n",
"Step size 0.97, max leapfrog steps: 7\n",
"Step size 0.99, max leapfrog steps: 3\n",
"Step size 1.02, max leapfrog steps: 3\n",
"Step size 1.04, max leapfrog steps: 3\n",
"Step size 1.06, max leapfrog steps: 3\n",
"Step size 1.08, max leapfrog steps: 3\n",
"Step size 1.10, max leapfrog steps: 3\n",
"Step size 1.12, max leapfrog steps: 3\n",
"Step size 1.14, max leapfrog steps: 3\n",
"Step size 1.16, max leapfrog steps: 3\n",
"Step size 1.18, max leapfrog steps: 3\n",
"Step size 1.20, max leapfrog steps: 3\n",
"Step size 1.22, max leapfrog steps: 3\n",
"Step size 1.24, max leapfrog steps: 3\n",
"Step size 1.26, max leapfrog steps: 3\n",
"Step size 1.28, max leapfrog steps: 3\n",
"Step size 1.30, max leapfrog steps: 3\n",
"Step size 1.32, max leapfrog steps: 3\n",
"Step size 1.34, max leapfrog steps: 3\n",
"Step size 1.36, max leapfrog steps: 3\n",
"Step size 1.38, max leapfrog steps: 3\n",
"Step size 1.40, max leapfrog steps: 3\n",
"Step size 1.42, max leapfrog steps: 3\n",
"Step size 1.44, max leapfrog steps: 3\n",
"Step size 1.46, max leapfrog steps: 3\n",
"Step size 1.48, max leapfrog steps: 3\n",
"Step size 1.50, max leapfrog steps: 3\n",
"Step size 1.52, max leapfrog steps: 3\n",
"Step size 1.54, max leapfrog steps: 3\n",
"Step size 1.56, max leapfrog steps: 3\n",
"Step size 1.58, max leapfrog steps: 3\n",
"Step size 1.60, max leapfrog steps: 3\n",
"Step size 1.62, max leapfrog steps: 3\n",
"Step size 1.64, max leapfrog steps: 3\n",
"Step size 1.66, max leapfrog steps: 3\n",
"Step size 1.68, max leapfrog steps: 3\n",
"Step size 1.70, max leapfrog steps: 3\n",
"Step size 1.72, max leapfrog steps: 3\n",
"Step size 1.74, max leapfrog steps: 3\n",
"Step size 1.76, max leapfrog steps: 3\n",
"Step size 1.78, max leapfrog steps: 3\n",
"Step size 1.80, max leapfrog steps: 3\n",
"Step size 1.82, max leapfrog steps: 3\n",
"Step size 1.84, max leapfrog steps: 3\n",
"Step size 1.86, max leapfrog steps: 3\n",
"Step size 1.88, max leapfrog steps: 3\n",
"Step size 1.90, max leapfrog steps: 3\n",
"Step size 1.92, max leapfrog steps: 3\n",
"Step size 1.94, max leapfrog steps: 1\n",
"Step size 1.96, max leapfrog steps: 1\n",
"Step size 1.98, max leapfrog steps: 1\n",
"Step size 2.00, max leapfrog steps: 1\n"
]
}
],
"source": [
"# Sweep 100 evenly spaced step sizes in [0.01, 2], using the loop index as the PRNG seed.\n",
"for seed, step_size in enumerate(np.linspace(0.01, 2, num=100)):\n",
"    chain_results = mcmc(random.PRNGKey(seed), step_size)\n",
"    max_leapfrog_steps = np.max(chain_results.num_steps)\n",
"    print(\"Step size {:.2f}, max leapfrog steps: {}\".format(step_size, max_leapfrog_steps))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 24 chains sampling a 400-dimensional Gaussian with step-size adaptation (only the 500 warmup iterations are collected)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"max leapfrog steps: 63\n"
]
}
],
"source": [
"# Use an explicit PRNG key instead of the `seed` variable leaked from the sweep loop,\n",
"# so this cell does not depend on that loop having run. 99 is the value `seed`\n",
"# held after the loop, which keeps the recorded output unchanged.\n",
"# num_samples=0: only the num_warmup iterations are collected.\n",
"chain_results = mcmc(random.PRNGKey(99), dim=400, adapt_step_size=True,\n",
"                     num_warmup=500, num_samples=0)\n",
"print(\"max leapfrog steps: {}\".format(np.max(chain_results.num_steps)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment