Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save tdhopper/20a6ac9ff068cc1ab032db7a0634a424 to your computer and use it in GitHub Desktop.
Save tdhopper/20a6ac9ff068cc1ab032db7a0634a424 to your computer and use it in GitHub Desktop.
Metropolis-Hastings runs ~450x faster than NUTS on this hierarchical beta-binomial model (PyMC3 3.0.rc2, Python 3).
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np \n",
"import pymc3 as pm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"'3.0.rc2'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pm.__version__"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"success = np.array([4, 125, 592, 2, 810, 0, 0, 1, 0, 104, 0, 0, 12, 70, 1, 64, 0, 0, 50, 1, 0, 0, 0, 0, 0, 0, 6, 3, 1, 0, 2, 0, 0, 2, 1, 2, 0, 1, 0, 2, 2, 1, 1, 0, 27, 1, 30, 0, 4, 0, 5, 0, 0, 0, 0, 0, 0, 0, 4, 2, 8, 0, 0, 0, 0, 0, 0, 1, 0, 0, 4, 2, 0, 0, 5, 0, 62, 0, 0, 0, 0, 13, 15, 0, 0, 8, 0, 0, 6, 17, 3, 0, 0, 0, 0, 3, 129, 37, 3, 5, 0, 9, 57, 0, 1, 0, 0, 0, 20, 2, 0, 5, 0, 105, 5, 30, 0, 0, 1, 92, 7, 8, 0, 4, 0, 0, 12, 0, 14, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 7, 14, 0, 217, 1, 1, 0, 0, 21, 0, 1, 2, 0, 0, 157, 0, 0, 0, 1, 0, 7, 0, 9, 4, 15, 1, 8, 0, 0, 0, 0, 41, 0, 2, 0, 0, 2, 164, 108, 9, 13, 0, 0, 0, 52, 41, 2, 0, 0, 2, 10, 0, 4, 0, 1, 7, 4, 1, 0, 797, 1, 0, 1, 0, 1, 2, 0, 0, 0, 20, 2804, 0, 0, 0, 37, 0, 1, 0, 0, 0, 1, 0, 0, 0, 28, 0, 5, 0, 4, 376, 0, 2199, 1, 0, 0, 0, 1, 0, 1, 2, 0, 0, 0, 0, 3, 0, 0, 1, 0, 0, 0, 0, 80, 0, 5, 27, 1, 65, 0, 0, 3, 3, 0, 0, 0, 1, 0, 0, 11, 2, 0, 3, 11, 1, 0, 413, 3, 11, 0, 0, 0, 506, 0, 19, 0, 0, 1, 2, 468, 0, 0, 0, 0, 10, 0, 0, 5, 0, 0, 0, 103, 21, 14, 5, 0, 2, 0, 8, 0, 0, 25, 19, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 19, 0, 0, 21, 1, 0, 8, 9, 8, 16, 0, 0, 0, 0, 44, 3, 1, 4, 0, 0, 18, 4, 1, 0, 0, 0, 1, 3, 27, 0, 0, 1, 1, 0, 10, 1, 1, 15, 0, 0, 9, 0, 0, 0, 0, 0, 7, 0, 0, 1, 0, 2, 10, 0, 0, 0, 32, 0, 0, 2, 11, 1, 6, 104, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 2, 10, 0, 0, 2, 9, 12, 2, 0, 0, 0, 6, 0, 0, 1, 1, 0, 0, 1, 0, 2, 0, 0, 0, 0, 12, 0, 1, 0, 36, 0, 164, 0, 0, 5, 6, 0, 1, 13, 0, 6, 12, 1, 0, 1, 0, 1445, 0, 0, 30, 0, 37, 1, 0, 1, 3, 0, 0, 4, 0, 0, 0, 0, 0, 2, 0, 3, 0, 0, 9, 1, 0, 1, 0, 2, 0, 0, 0, 1307, 0, 0, 2, 0, 1, 4, 2, 322, 0, 0, 0, 2, 3, 0, 2, 0, 0, 0, 24, 239, 0, 0, 23, 0, 0, 0, 2, 6, 2, 18, 225, 5, 0, 3, 6, 183, 0, 0, 2, 43, 0, 0, 0, 1, 0, 1, 8, 0, 3, 0, 0, 0, 51, 0, 1, 5, 7, 5, 2, 4, 0, 4, 1, 19, 76, 0, 0, 0, 0, 10, 5, 7, 3, 2, 2, 7, 0, 1, 1, 1, 0, 11, 12, 0, 0, 167, 4, 3, 0, 0, 0, 1, 0, 0, 36, 0, 0, 147, 1, 6, 0, 0, 0, 2, 81, 0, 0, 0, 0, 0, 0, 1, 6, 1, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0, 2, 50, 2, 1, 109, 0, 25, 0, 6, 175, 
1, 1, 0, 1, 2, 0, 0, 360, 9, 17, 157, 0, 0, 0, 19, 0, 76, 0, 1, 23, 0, 0, 0, 11, 23, 0, 2, 0, 0, 9, 1, 0, 643, 2, 801, 33, 0, 0, 0, 2, 0, 10, 403, 0, 0, 1, 5, 0, 1, 0, 0, 11, 0, 5, 0, 0, 4, 1, 6, 37, 0, 0, 0, 0, 31, 0, 0, 9, 75, 0, 0, 1, 65, 8, 93, 1, 3, 1, 0, 0, 958, 2, 0, 0, 0, 0, 1, 0, 0, 1, 2, 0, 0, 287, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 3, 0, 4, 19, 4, 0, 0, 3, 0, 6, 0, 9, 262, 18, 4, 0, 0, 0, 11, 259, 1, 0, 5, 0, 0, 2, 0, 0, 0, 0, 0, 71, 0, 0, 6, 0, 0, 50, 0, 1, 1, 0, 0, 0, 0, 0, 9, 0, 4, 0, 0, 21, 17, 1, 1, 5, 0, 1, 1, 2, 0, 0, 0, 1, 3, 0, 5, 0])\n",
"attempt = np.array([1860, 12802, 11765, 7, 13633, 1, 266, 255, 7, 4479, 1, 2, 924, 10871, 8, 351, 1, 7, 12805, 72, 94, 13, 2, 12, 31, 4, 732, 5, 3, 14, 62, 158, 24, 528, 369, 7, 12, 176, 3, 9, 44, 261, 349, 12, 791, 181, 281, 432, 1913, 814, 33, 166, 1, 3, 1366, 65, 18, 43, 2013, 60, 430, 8, 4, 2, 116, 39, 6, 998, 87, 136, 1542, 225, 1, 5, 15, 1, 7358, 31, 399, 2, 38, 848, 287, 9, 1, 25, 81, 5, 1907, 400, 12, 14060, 1, 2, 161, 1505, 4087, 9217, 17, 561, 109, 1031, 205, 7, 64, 41, 133, 72, 90958, 31, 59, 21, 11, 2721, 1429, 217, 13, 7, 417, 2352, 213, 472, 212, 332, 11, 1, 472, 4, 3316, 761, 1, 53, 142, 48, 869, 16, 4, 228, 33, 26, 1779, 31, 1343, 477, 587, 189, 12, 2641, 9, 1, 1186, 40, 49, 14117, 84, 5, 26, 51, 14, 301, 14, 1956, 662, 3133, 5, 23, 240, 42, 632, 1922, 5761, 53, 1980, 35, 124, 2, 577, 11636, 913, 193, 1020, 595, 1089, 1103, 3479, 432, 2, 6, 2192, 3128, 53, 1089, 245, 982, 15, 162, 465, 52, 32860, 1779, 45, 2454, 5, 236, 928, 251, 26, 76, 5261, 16164, 31, 186, 1, 108, 8, 34, 73, 9, 3, 258, 63, 1, 2, 280, 59, 1149, 3, 130, 112456, 20, 204290, 1575, 3, 206, 26, 787, 131, 68, 664, 37, 116, 12, 8, 160, 4, 2, 93, 7, 1, 21979, 7, 8848, 326, 1629, 905, 2409, 164, 25, 2, 18, 1043, 1, 8, 212, 19, 10, 105, 301, 309, 18, 7, 24, 1, 25, 816, 318, 111, 30, 231, 209, 7758, 34, 29, 3, 1313, 23, 183, 6197, 113, 34, 407, 389, 232, 472, 5, 216, 2474, 17, 62, 51115, 381, 157, 1981, 1, 2842, 32, 733, 17, 42, 49, 474, 16, 68, 2, 3, 257, 10926, 5841, 29, 2, 43, 121, 1845, 8, 192, 2141, 388, 14, 363, 345, 1133, 6375, 22, 19, 13, 12, 71, 92, 85, 161, 242, 209, 15281, 271, 166, 3, 8, 3422, 1016, 82, 43211, 14, 127, 20, 20, 3, 14, 91, 470, 239, 469, 1, 61, 101, 5, 66, 14, 3, 714, 36, 14, 451, 1, 937, 1168, 30, 4, 265, 8289, 9, 14109, 768, 75, 953, 629, 7270, 8, 1, 2, 32, 5, 6, 3, 4, 56, 2, 101, 13, 78, 58, 2059, 30770, 102, 87, 349, 1, 1527, 11, 20, 8, 116, 587, 9, 301, 361, 8, 184, 1, 28, 104, 17, 1021, 26, 68, 29136, 5291, 49, 358, 57, 2, 157, 344, 1, 64, 47, 319, 286, 343, 
47, 78, 146, 291, 33238, 7, 1, 1448, 9, 12033, 12, 55, 9, 4, 2, 7, 29, 534, 16, 7888, 17, 32, 438, 6, 124, 47, 18, 450, 271, 12, 36, 231, 417, 7, 2, 933, 16967, 134, 21, 8529, 6, 381, 50, 595, 14568, 1, 1, 42, 140, 39, 6, 967, 2, 1, 1, 1589, 7862, 4, 31, 3020, 63, 255, 25, 76, 70, 1817, 24, 3884, 16, 1, 1735, 897, 642, 6, 73, 1972, 143, 6, 340, 3, 131, 3, 7, 1033, 3, 844, 15, 471, 26, 22184, 63, 156, 3395, 26, 632, 1320, 1070, 24, 406, 1, 13013, 1722, 138, 31, 2, 6, 148, 176, 2149, 1342, 984, 129, 4367, 4, 2, 29, 6, 25, 1373, 171, 201, 56, 9559, 7226, 93, 1, 189, 104, 2, 544, 5, 4915, 48, 2, 20015, 145, 14, 722, 409, 156, 670, 3153, 62, 15, 12, 21, 44, 33, 1, 68, 501, 164, 1, 50, 1, 170, 117, 39, 111, 1, 2, 1020, 2, 125, 5, 94, 1300, 56, 2, 2207, 1569, 295, 62, 414, 5013, 83, 186, 32, 2, 202, 39, 1, 138238, 553, 1167, 650, 1, 84, 1568, 4598, 45, 3114, 2, 2195, 4726, 79, 15, 300, 763, 3593, 7, 375, 108, 5, 56, 3678, 6, 2276, 57, 4172, 991, 7, 48, 1, 1022, 17, 52, 1414, 45, 61, 87, 312, 516, 518, 5, 437, 677, 128, 22, 9, 97, 71, 7, 18, 344, 1, 85, 465, 33, 963, 7, 45, 1442, 3304, 2, 11, 66, 9141, 3475, 196, 91, 379, 17, 624, 52, 6851, 4189, 12, 25, 33, 42, 11, 1, 44, 3451, 48, 1, 1, 5912, 68, 224, 16, 3, 41, 355, 5, 34, 634, 1, 49, 138, 1626, 7, 1218, 511, 60, 203, 299, 125, 752, 20, 7, 43, 30117, 44, 22, 8, 2, 1, 513, 7238, 2, 58, 47, 136, 3, 19, 11, 16, 23, 7, 8, 155, 154, 65, 1456, 165, 82, 225, 64, 25, 39, 6, 17, 19, 1, 23, 195, 8, 3714, 22, 11, 684, 1164, 398, 2, 1031, 212, 6, 2, 238, 25, 1, 15, 307, 180, 4, 6, 332])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def beta_bernoulli(attempt, success, step='metropolis', niter=10000, mu_hypers=(1, 9), kappa_hypers=(1, 1)):\n",
"    \"\"\"Build a hierarchical beta-binomial model and sample its posterior.\n",
"\n",
"    Model: mu ~ Beta(*mu_hypers), kappa ~ Gamma(*kappa_hypers),\n",
"    theta[i] ~ Beta(mu * kappa, (1 - mu) * kappa),\n",
"    success[i] ~ Binomial(attempt[i], theta[i]).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    attempt : array-like\n",
"        Number of trials per unit; its length sets the number of theta values.\n",
"    success : array-like\n",
"        Observed number of successes per unit (same length as attempt).\n",
"    step : str\n",
"        Sampler, case-insensitive: 'metropolis', 'nuts' or 'nuts_advi'.\n",
"        'metropolis' and 'nuts' start from pm.find_MAP(); 'nuts_advi' starts\n",
"        from the result of a 200000-iteration ADVI run.\n",
"    niter : int\n",
"        Number of draws passed to pm.sample.\n",
"    mu_hypers, kappa_hypers : tuple\n",
"        Positional hyperparameters for the Beta prior on mu and the\n",
"        Gamma prior on kappa.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The trace returned by pm.sample.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If step is not one of the supported sampler names.\n",
"    \"\"\"\n",
"    method = step.lower()\n",
"    if method not in ('metropolis', 'nuts', 'nuts_advi'):\n",
"        # Validate before building the model so a typo in `step` fails fast.\n",
"        raise ValueError(\"unknown step {!r}; expected 'metropolis', 'nuts' or 'nuts_advi'\".format(step))\n",
"    with pm.Model() as model:\n",
"        mu = pm.Beta('mu', *mu_hypers)\n",
"        kappa = pm.Gamma('kappa', *kappa_hypers)\n",
"        # Per-unit success probabilities: Beta with mean mu, concentration kappa.\n",
"        p = pm.Beta('theta', mu * kappa, (1 - mu) * kappa, shape=len(attempt))\n",
"        y = pm.Binomial('y', n=attempt,\n",
"                        p=p,\n",
"                        observed=success)\n",
"        if method == 'metropolis':\n",
"            print(\"Sample with metropolis\")\n",
"            return pm.sample(niter, step=pm.Metropolis(), progressbar=True,\n",
"                             start=pm.find_MAP())\n",
"        elif method == 'nuts':\n",
"            print(\"Sample with NUTS\")\n",
"            return pm.sample(niter, step=pm.NUTS(), progressbar=True,\n",
"                             start=pm.find_MAP())\n",
"        else:  # method == 'nuts_advi' (validated above)\n",
"            print(\"Sample with NUTS with ADVI Initialization\")\n",
"            # Only the first element of ADVI's result is used, as the NUTS start point.\n",
"            start, _, _ = pm.variational.advi(n=200000)\n",
"            return pm.sample(niter, step=pm.NUTS(), progressbar=True,\n",
"                             start=start)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Applied logodds-transform to mu and added transformed mu_logodds_ to model.\n",
"Applied log-transform to kappa and added transformed kappa_log_ to model.\n",
"Applied logodds-transform to theta and added transformed theta_logodds_ to model.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample with metropolis\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10000/10000 [00:07<00:00, 1401.86it/s]\n"
]
},
{
"data": {
"text/plain": [
"<MultiTrace: 1 chains, 10000 iterations, 6 variables>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Random-walk Metropolis with the default niter (10000 draws).\n",
"beta_bernoulli(attempt, success, step='metropolis')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Applied logodds-transform to mu and added transformed mu_logodds_ to model.\n",
"Applied log-transform to kappa and added transformed kappa_log_ to model.\n",
"Applied logodds-transform to theta and added transformed theta_logodds_ to model.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample with NUTS\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 50/50 [00:27<00:00, 2.92it/s]\n"
]
},
{
"data": {
"text/plain": [
"<MultiTrace: 1 chains, 50 iterations, 6 variables>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NUTS started at the MAP estimate, 50 draws.\n",
"beta_bernoulli(attempt, success, step='nuts', niter=50)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Applied logodds-transform to mu and added transformed mu_logodds_ to model.\n",
"Applied log-transform to kappa and added transformed kappa_log_ to model.\n",
"Applied logodds-transform to theta and added transformed theta_logodds_ to model.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample with NUTS with ADVI Initilization\n",
"Iteration 0 [0%]: ELBO = -147658.03\n",
"Iteration 20000 [10%]: Average ELBO = -28632.13\n",
"Iteration 40000 [20%]: Average ELBO = -2059.87\n",
"Iteration 60000 [30%]: Average ELBO = -2114.85\n",
"Iteration 80000 [40%]: Average ELBO = -2188.04\n",
"Iteration 100000 [50%]: Average ELBO = -2257.17\n",
"Iteration 120000 [60%]: Average ELBO = -2322.03\n",
"Iteration 140000 [70%]: Average ELBO = -2384.65\n",
"Iteration 160000 [80%]: Average ELBO = -2448.75\n",
"Iteration 180000 [90%]: Average ELBO = -2497.9\n",
"Finished [100%]: Average ELBO = -2558.98\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 50/50 [00:27<00:00, 1.95it/s]\n"
]
},
{
"data": {
"text/plain": [
"<MultiTrace: 1 chains, 50 iterations, 6 variables>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NUTS started from the ADVI result, 50 draws.\n",
"beta_bernoulli(attempt, success, step='nuts_advi', niter=50)"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda env:ds-757-captcha-sampling]",
"language": "python",
"name": "conda-env-ds-757-captcha-sampling-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
@tdhopper
Copy link
Author

tdhopper commented Nov 8, 2016

For posterity, here's how you can speed up NUTS:
http://til.tdhopper.com/notes/speeding-up-pymc3-nuts-sampler

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment