Created
September 20, 2019 21:27
-
-
Save fehiepsi/6226401d6960de647a2e1d58b640605b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"\n", | |
"import numpy as onp\n", | |
"from numpy.testing import assert_allclose\n", | |
"import pytest\n", | |
"\n", | |
"from jax import pmap, random, vmap\n", | |
"from jax.lib import xla_bridge\n", | |
"import jax.numpy as np\n", | |
"from jax.scipy.special import logit\n", | |
"from jax.config import config as jax_config; jax_config.update('jax_platform_name', 'gpu')\n", | |
"\n", | |
"import numpyro\n", | |
"import numpyro.distributions as dist\n", | |
"from numpyro.distributions import constraints\n", | |
"from numpyro.hmc_util import initialize_model\n", | |
"from numpyro.mcmc import hmc, mcmc\n", | |
"from numpyro.util import fori_collect" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"map_fn = vmap\n", | |
"N, dim = 10000, 100\n", | |
"warmup_steps, num_samples = 1000, 9000\n", | |
"data = random.normal(random.PRNGKey(0), (N, dim))\n", | |
"true_coefs = random.normal(random.PRNGKey(1), (dim,))\n", | |
"logits = np.sum(true_coefs * data, axis=-1)\n", | |
"labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(2))\n", | |
"\n", | |
"def model(labels):\n", | |
" coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))\n", | |
" logits = np.dot(data, coefs)\n", | |
" return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## num_chains = 10" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"num_chains = 10\n", | |
"rngs = random.split(random.PRNGKey(3), num_chains)\n", | |
"init_params, potential_fn, constrain_fn = initialize_model(rngs, model, labels)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')\n", | |
"init_kernel_map = map_fn(lambda init_param, rng: init_kernel(\n", | |
" init_param, num_warmup=warmup_steps, run_warmup=False, rng=rng))\n", | |
"init_states = init_kernel_map(init_params, rngs)\n", | |
"x = init_states.z.copy()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### fori vmap" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 10000/10000 [03:18<00:00, 57.11it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"gradients/s 4780.861593848532\n" | |
] | |
} | |
], | |
"source": [ | |
"import time\n", | |
"tic = time.time()\n", | |
"num_steps = fori_collect(0, num_samples + warmup_steps, vmap(sample_kernel), init_states,\n", | |
" transform=lambda x: x.num_steps, progbar=True).copy()\n", | |
"toc = time.time()\n", | |
"print(\"gradients/s\", num_steps.sum() / (toc - tic))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Note that of the roughly 3 minutes (03:18) reported above, about 25 s was spent on compilation." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## num_chains = 100" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 10000/10000 [19:35<00:00, 8.77it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"gradients/s 7243.712290338247\n" | |
] | |
} | |
], | |
"source": [ | |
"num_chains = 100\n", | |
"rngs = random.split(random.PRNGKey(3), num_chains)\n", | |
"init_params, potential_fn, constrain_fn = initialize_model(rngs, model, labels)\n", | |
"\n", | |
"init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')\n", | |
"init_kernel_map = map_fn(lambda init_param, rng: init_kernel(\n", | |
" init_param, num_warmup=warmup_steps, run_warmup=False, rng=rng))\n", | |
"init_states = init_kernel_map(init_params, rngs)\n", | |
"x = init_states.z.copy()\n", | |
"\n", | |
"import time\n", | |
"tic = time.time()\n", | |
"num_steps = fori_collect(0, num_samples + warmup_steps, vmap(sample_kernel), init_states,\n", | |
" transform=lambda x: x.num_steps, progbar=True).copy()\n", | |
"toc = time.time()\n", | |
"print(\"gradients/s\", num_steps.sum() / (toc - tic))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## num_chains = 1000" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 1000/1000 [18:05<00:00, 1.06it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"gradients/s 9721.04384608474\n" | |
] | |
} | |
], | |
"source": [ | |
"num_chains = 1000\n", | |
"warmup_steps, num_samples = 100, 900\n", | |
"rngs = random.split(random.PRNGKey(3), num_chains)\n", | |
"init_params, potential_fn, constrain_fn = initialize_model(rngs, model, labels)\n", | |
"\n", | |
"init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')\n", | |
"init_kernel_map = map_fn(lambda init_param, rng: init_kernel(\n", | |
" init_param, num_warmup=warmup_steps, run_warmup=False, rng=rng))\n", | |
"init_states = init_kernel_map(init_params, rngs)\n", | |
"x = init_states.z.copy()\n", | |
"\n", | |
"import time\n", | |
"tic = time.time()\n", | |
"num_steps = fori_collect(0, num_samples + warmup_steps, vmap(sample_kernel), init_states,\n", | |
" transform=lambda x: x.num_steps, progbar=True).copy()\n", | |
"toc = time.time()\n", | |
"print(\"gradients/s\", num_steps.sum() / (toc - tic))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## num_chains = 1000 on GPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 1000/1000 [01:09<00:00, 14.46it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"gradients/s 155579.4827417393\n" | |
] | |
} | |
], | |
"source": [ | |
"num_chains = 1000\n", | |
"warmup_steps, num_samples = 100, 900\n", | |
"rngs = random.split(random.PRNGKey(3), num_chains)\n", | |
"init_params, potential_fn, constrain_fn = initialize_model(rngs, model, labels)\n", | |
"\n", | |
"init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')\n", | |
"init_kernel_map = map_fn(lambda init_param, rng: init_kernel(\n", | |
" init_param, num_warmup=warmup_steps, run_warmup=False, rng=rng))\n", | |
"init_states = init_kernel_map(init_params, rngs)\n", | |
"x = init_states.z.copy()\n", | |
"\n", | |
"import time\n", | |
"tic = time.time()\n", | |
"num_steps = fori_collect(0, num_samples + warmup_steps, vmap(sample_kernel), init_states,\n", | |
" transform=lambda x: x.num_steps, progbar=True).copy()\n", | |
"toc = time.time()\n", | |
"print(\"gradients/s\", num_steps.sum() / (toc - tic))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment