Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Created September 21, 2019 02:09
Show Gist options
  • Save fehiepsi/b9ff5cb6df8b110172471084df1cbe99 to your computer and use it in GitHub Desktop.
Save fehiepsi/b9ff5cb6df8b110172471084df1cbe99 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"from jax import random\n",
"import jax.numpy as np\n",
"from jax.config import config; config.update('jax_platform_name', 'gpu')\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.mcmc import MCMC, NUTS"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"N, dim = 10000, 100\n",
"data = random.normal(random.PRNGKey(0), (N, dim))\n",
"true_coefs = random.normal(random.PRNGKey(1), (dim,))\n",
"logits = np.dot(data, true_coefs)\n",
"labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(2))\n",
"\n",
"def model(labels):\n",
" coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))\n",
" logits = np.dot(data, coefs)\n",
" return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2000/2000 [02:11<00:00, 15.17it/s]\n"
]
},
{
"data": {
"text/plain": [
"DeviceArray(127698.93, dtype=float32)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tic = time.time()\n",
"mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=1000, chain_method='vectorized')\n",
"mcmc.run(random.PRNGKey(0), labels, collect_fields=('z', 'num_steps'), collect_warmup=True)\n",
"toc = time.time()\n",
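"# leapfrog steps per second: num_steps summed over all chains and iterations (warmup included), divided by the wall-clock time between tic and toc\n",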
"mcmc.get_samples()[1].sum() / (toc - tic)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment