Last active
April 8, 2019 05:08
-
-
Save fehiepsi/ab876d1db27c2277dc6c0cf1ab5c8ff0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import time\n", | |
"\n", | |
"N, dim = 3000, 3\n", | |
"warmup_steps, num_samples = 1000, 20000" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### numpyro" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import jax.numpy as np\n", | |
"import jax.random as random\n", | |
"from jax.scipy.special import expit\n", | |
"\n", | |
"import numpyro.distributions as dist\n", | |
"from numpyro.distributions.util import validation_disabled\n", | |
"from numpyro.handlers import sample\n", | |
"from numpyro.hmc_util import initialize_model\n", | |
"from numpyro.mcmc import hmc_kernel\n", | |
"from numpyro.util import tscan" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"rng = random.PRNGKey(0)\n", | |
"data = random.normal(rng, (N, dim))\n", | |
"true_coefs = np.arange(1., dim + 1.)\n", | |
"logits = np.sum(true_coefs * data, axis=-1)\n", | |
"labels = dist.bernoulli(logits, is_logits=True).rvs(random_state=rng)\n", | |
"\n", | |
"def model(labels):\n", | |
" coefs = sample('coefs', dist.norm(np.zeros(dim), np.ones(dim)))\n", | |
" logits = np.sum(coefs * data, axis=-1)\n", | |
" return sample('obs', dist.bernoulli(logits, is_logits=True), obs=labels)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time for complie and init: 6.054578542709351\n", | |
"Time to compile: 1.8476543426513672\n", | |
"Time to compile and generate 20000 samples: 2.8188066482543945\n" | |
] | |
} | |
], | |
"source": [ | |
"with validation_disabled():\n", | |
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n", | |
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=\"HMC\")\n", | |
" start = time.time()\n", | |
" hmc_state = init_kernel(init_params,\n", | |
" step_size=0.1,\n", | |
" num_steps=15,\n", | |
" num_warmup_steps=warmup_steps)\n", | |
" print(\"Time for complie and init:\", time.time() - start)\n", | |
" start = time.time()\n", | |
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n", | |
" hmc_state, np.arange(1))\n", | |
" print(\"Time to compile:\", time.time() - start)\n", | |
" start = time.time()\n", | |
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n", | |
" hmc_state, np.arange(num_samples))\n", | |
" print(\"Time to compile and generate 20000 samples:\", time.time() - start)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time for complie and init: 11.951331853866577\n", | |
"Time to compile: 8.162105798721313\n", | |
"Time to compile and generate 20000 samples: 11.791397094726562\n" | |
] | |
} | |
], | |
"source": [ | |
"with validation_disabled():\n", | |
" init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n", | |
" init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=\"NUTS\")\n", | |
" start = time.time()\n", | |
" hmc_state = init_kernel(init_params,\n", | |
" step_size=0.1,\n", | |
" num_steps=15,\n", | |
" num_warmup_steps=warmup_steps)\n", | |
" print(\"Time for complie and init:\", time.time() - start)\n", | |
" start = time.time()\n", | |
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n", | |
" hmc_state, np.arange(1))\n", | |
" print(\"Time to compile:\", time.time() - start)\n", | |
" start = time.time()\n", | |
" hmc_states = tscan(lambda state, i: sample_kernel(state),\n", | |
" hmc_state, np.arange(num_samples))\n", | |
" print(\"Time to compile and generate 20000 samples:\", time.time() - start)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### pyro" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"import pyro\n", | |
"import pyro.distributions as pdist\n", | |
"from pyro.infer.mcmc import HMC, MCMC, NUTS" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pyro.set_rng_seed(0)\n", | |
"data = torch.randn(N, dim)\n", | |
"true_coefs = torch.arange(1., dim + 1.)\n", | |
"logits = (true_coefs * data).sum(-1)\n", | |
"labels = pdist.Bernoulli(logits=logits).sample()\n", | |
"\n", | |
"def model(data):\n", | |
" coefs = pyro.sample('beta', pdist.Normal(torch.zeros(dim), torch.ones(dim)))\n", | |
" logits = (coefs * data).sum(-1)\n", | |
" return pyro.sample('y', pdist.Bernoulli(logits=logits), obs=labels)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time for complie and init: 4.401107311248779\n", | |
"Time for complie, init, and run: 42.86292505264282\n" | |
] | |
} | |
], | |
"source": [ | |
"start = time.time()\n", | |
"hmc_kernel = HMC(model, step_size=0.1, num_steps=15, jit_compile=True, ignore_jit_warnings=True)\n", | |
"mcmc_run = MCMC(hmc_kernel, num_samples=1, warmup_steps=warmup_steps, disable_progbar=True).run(data)\n", | |
"print(\"Time for complie and init:\", time.time() - start)\n", | |
"start = time.time()\n", | |
"hmc_kernel = HMC(model, step_size=0.1, num_steps=15, jit_compile=True, ignore_jit_warnings=True)\n", | |
"mcmc_run = MCMC(hmc_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=True).run(data)\n", | |
"print(\"Time for complie, init, and run:\", time.time() - start)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time for complie and init: 5.244949817657471\n", | |
"Time for complie, init, and run: 118.32065653800964\n" | |
] | |
} | |
], | |
"source": [ | |
"start = time.time()\n", | |
"hmc_kernel = NUTS(model, step_size=0.1, jit_compile=True, ignore_jit_warnings=True)\n", | |
"mcmc_run = MCMC(hmc_kernel, num_samples=1, warmup_steps=warmup_steps, disable_progbar=True).run(data)\n", | |
"print(\"Time for complie and init:\", time.time() - start)\n", | |
"start = time.time()\n", | |
"hmc_kernel = NUTS(model, step_size=0.1, jit_compile=True, ignore_jit_warnings=True)\n", | |
"mcmc_run = MCMC(hmc_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=True).run(data)\n", | |
"print(\"Time for complie, init, and run:\", time.time() - start)" | |
] | |
}, | |
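{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**So numpyro sampling is roughly 30-40x faster than pyro: about 1s vs. 38s for HMC and about 3.6s vs. 113s for NUTS (20000 samples, with compile/init time subtracted).**" | |
] | |
}, | |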
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Within numpyro, NUTS takes more than 3s to generate 20000 samples (11.8s total minus 8.2s to compile), while HMC takes only about 1s (2.8s total minus 1.8s to compile).**" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Let's see if the NUTS implementation adds overhead over HMC" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(40000, dtype=int32)" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n", | |
"init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=\"HMC\")\n", | |
"hmc_state = init_kernel(init_params,\n", | |
" step_size=0.1,\n", | |
" num_steps=15,\n", | |
" num_warmup_steps=warmup_steps)\n", | |
"hmc_states = tscan(lambda state, i: sample_kernel(state),\n", | |
" hmc_state, np.arange(num_samples), fields=(3,))\n", | |
"np.sum(hmc_states.num_steps).copy()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(113032, dtype=int32)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"init_params, potential_fn = initialize_model(random.PRNGKey(2), model, (labels,), {})\n", | |
"init_kernel, sample_kernel = hmc_kernel(potential_fn, algo=\"NUTS\")\n", | |
"hmc_state = init_kernel(init_params,\n", | |
" step_size=0.1,\n", | |
" num_steps=15,\n", | |
" num_warmup_steps=warmup_steps)\n", | |
"hmc_states = tscan(lambda state, i: sample_kernel(state),\n", | |
" hmc_state, np.arange(num_samples), fields=(3,))\n", | |
"np.sum(hmc_states.num_steps).copy()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**So NUTS takes about 2.8x as many Verlet (leapfrog) steps as HMC (113032 vs. 40000), which accounts for most of its longer sampling time -> iterative NUTS has just a small overhead over HMC.**" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Sampling is pretty fast. The only remaining issue is the compilation time. In addition, we compile twice: once at init and once at scan." | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment