Last active
November 8, 2016 20:49
-
-
Save tdhopper/20a6ac9ff068cc1ab032db7a0634a424 to your computer and use it in GitHub Desktop.
Metropolis-Hastings runs 450x faster than NUTS (in iterations per second) on this model, using PyMC3 3.0.rc2 and Python 3.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np \n", | |
"import pymc3 as pm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'3.0.rc2'" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pm.__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"success = np.array([4, 125, 592, 2, 810, 0, 0, 1, 0, 104, 0, 0, 12, 70, 1, 64, 0, 0, 50, 1, 0, 0, 0, 0, 0, 0, 6, 3, 1, 0, 2, 0, 0, 2, 1, 2, 0, 1, 0, 2, 2, 1, 1, 0, 27, 1, 30, 0, 4, 0, 5, 0, 0, 0, 0, 0, 0, 0, 4, 2, 8, 0, 0, 0, 0, 0, 0, 1, 0, 0, 4, 2, 0, 0, 5, 0, 62, 0, 0, 0, 0, 13, 15, 0, 0, 8, 0, 0, 6, 17, 3, 0, 0, 0, 0, 3, 129, 37, 3, 5, 0, 9, 57, 0, 1, 0, 0, 0, 20, 2, 0, 5, 0, 105, 5, 30, 0, 0, 1, 92, 7, 8, 0, 4, 0, 0, 12, 0, 14, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 7, 14, 0, 217, 1, 1, 0, 0, 21, 0, 1, 2, 0, 0, 157, 0, 0, 0, 1, 0, 7, 0, 9, 4, 15, 1, 8, 0, 0, 0, 0, 41, 0, 2, 0, 0, 2, 164, 108, 9, 13, 0, 0, 0, 52, 41, 2, 0, 0, 2, 10, 0, 4, 0, 1, 7, 4, 1, 0, 797, 1, 0, 1, 0, 1, 2, 0, 0, 0, 20, 2804, 0, 0, 0, 37, 0, 1, 0, 0, 0, 1, 0, 0, 0, 28, 0, 5, 0, 4, 376, 0, 2199, 1, 0, 0, 0, 1, 0, 1, 2, 0, 0, 0, 0, 3, 0, 0, 1, 0, 0, 0, 0, 80, 0, 5, 27, 1, 65, 0, 0, 3, 3, 0, 0, 0, 1, 0, 0, 11, 2, 0, 3, 11, 1, 0, 413, 3, 11, 0, 0, 0, 506, 0, 19, 0, 0, 1, 2, 468, 0, 0, 0, 0, 10, 0, 0, 5, 0, 0, 0, 103, 21, 14, 5, 0, 2, 0, 8, 0, 0, 25, 19, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 19, 0, 0, 21, 1, 0, 8, 9, 8, 16, 0, 0, 0, 0, 44, 3, 1, 4, 0, 0, 18, 4, 1, 0, 0, 0, 1, 3, 27, 0, 0, 1, 1, 0, 10, 1, 1, 15, 0, 0, 9, 0, 0, 0, 0, 0, 7, 0, 0, 1, 0, 2, 10, 0, 0, 0, 32, 0, 0, 2, 11, 1, 6, 104, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 2, 10, 0, 0, 2, 9, 12, 2, 0, 0, 0, 6, 0, 0, 1, 1, 0, 0, 1, 0, 2, 0, 0, 0, 0, 12, 0, 1, 0, 36, 0, 164, 0, 0, 5, 6, 0, 1, 13, 0, 6, 12, 1, 0, 1, 0, 1445, 0, 0, 30, 0, 37, 1, 0, 1, 3, 0, 0, 4, 0, 0, 0, 0, 0, 2, 0, 3, 0, 0, 9, 1, 0, 1, 0, 2, 0, 0, 0, 1307, 0, 0, 2, 0, 1, 4, 2, 322, 0, 0, 0, 2, 3, 0, 2, 0, 0, 0, 24, 239, 0, 0, 23, 0, 0, 0, 2, 6, 2, 18, 225, 5, 0, 3, 6, 183, 0, 0, 2, 43, 0, 0, 0, 1, 0, 1, 8, 0, 3, 0, 0, 0, 51, 0, 1, 5, 7, 5, 2, 4, 0, 4, 1, 19, 76, 0, 0, 0, 0, 10, 5, 7, 3, 2, 2, 7, 0, 1, 1, 1, 0, 11, 12, 0, 0, 167, 4, 3, 0, 0, 0, 1, 0, 0, 36, 0, 0, 147, 1, 6, 0, 0, 0, 2, 81, 0, 0, 0, 0, 0, 0, 1, 6, 1, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0, 2, 50, 2, 1, 109, 0, 25, 0, 6, 175, 
1, 1, 0, 1, 2, 0, 0, 360, 9, 17, 157, 0, 0, 0, 19, 0, 76, 0, 1, 23, 0, 0, 0, 11, 23, 0, 2, 0, 0, 9, 1, 0, 643, 2, 801, 33, 0, 0, 0, 2, 0, 10, 403, 0, 0, 1, 5, 0, 1, 0, 0, 11, 0, 5, 0, 0, 4, 1, 6, 37, 0, 0, 0, 0, 31, 0, 0, 9, 75, 0, 0, 1, 65, 8, 93, 1, 3, 1, 0, 0, 958, 2, 0, 0, 0, 0, 1, 0, 0, 1, 2, 0, 0, 287, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 3, 0, 4, 19, 4, 0, 0, 3, 0, 6, 0, 9, 262, 18, 4, 0, 0, 0, 11, 259, 1, 0, 5, 0, 0, 2, 0, 0, 0, 0, 0, 71, 0, 0, 6, 0, 0, 50, 0, 1, 1, 0, 0, 0, 0, 0, 9, 0, 4, 0, 0, 21, 17, 1, 1, 5, 0, 1, 1, 2, 0, 0, 0, 1, 3, 0, 5, 0])\n", | |
"attempt = np.array([1860, 12802, 11765, 7, 13633, 1, 266, 255, 7, 4479, 1, 2, 924, 10871, 8, 351, 1, 7, 12805, 72, 94, 13, 2, 12, 31, 4, 732, 5, 3, 14, 62, 158, 24, 528, 369, 7, 12, 176, 3, 9, 44, 261, 349, 12, 791, 181, 281, 432, 1913, 814, 33, 166, 1, 3, 1366, 65, 18, 43, 2013, 60, 430, 8, 4, 2, 116, 39, 6, 998, 87, 136, 1542, 225, 1, 5, 15, 1, 7358, 31, 399, 2, 38, 848, 287, 9, 1, 25, 81, 5, 1907, 400, 12, 14060, 1, 2, 161, 1505, 4087, 9217, 17, 561, 109, 1031, 205, 7, 64, 41, 133, 72, 90958, 31, 59, 21, 11, 2721, 1429, 217, 13, 7, 417, 2352, 213, 472, 212, 332, 11, 1, 472, 4, 3316, 761, 1, 53, 142, 48, 869, 16, 4, 228, 33, 26, 1779, 31, 1343, 477, 587, 189, 12, 2641, 9, 1, 1186, 40, 49, 14117, 84, 5, 26, 51, 14, 301, 14, 1956, 662, 3133, 5, 23, 240, 42, 632, 1922, 5761, 53, 1980, 35, 124, 2, 577, 11636, 913, 193, 1020, 595, 1089, 1103, 3479, 432, 2, 6, 2192, 3128, 53, 1089, 245, 982, 15, 162, 465, 52, 32860, 1779, 45, 2454, 5, 236, 928, 251, 26, 76, 5261, 16164, 31, 186, 1, 108, 8, 34, 73, 9, 3, 258, 63, 1, 2, 280, 59, 1149, 3, 130, 112456, 20, 204290, 1575, 3, 206, 26, 787, 131, 68, 664, 37, 116, 12, 8, 160, 4, 2, 93, 7, 1, 21979, 7, 8848, 326, 1629, 905, 2409, 164, 25, 2, 18, 1043, 1, 8, 212, 19, 10, 105, 301, 309, 18, 7, 24, 1, 25, 816, 318, 111, 30, 231, 209, 7758, 34, 29, 3, 1313, 23, 183, 6197, 113, 34, 407, 389, 232, 472, 5, 216, 2474, 17, 62, 51115, 381, 157, 1981, 1, 2842, 32, 733, 17, 42, 49, 474, 16, 68, 2, 3, 257, 10926, 5841, 29, 2, 43, 121, 1845, 8, 192, 2141, 388, 14, 363, 345, 1133, 6375, 22, 19, 13, 12, 71, 92, 85, 161, 242, 209, 15281, 271, 166, 3, 8, 3422, 1016, 82, 43211, 14, 127, 20, 20, 3, 14, 91, 470, 239, 469, 1, 61, 101, 5, 66, 14, 3, 714, 36, 14, 451, 1, 937, 1168, 30, 4, 265, 8289, 9, 14109, 768, 75, 953, 629, 7270, 8, 1, 2, 32, 5, 6, 3, 4, 56, 2, 101, 13, 78, 58, 2059, 30770, 102, 87, 349, 1, 1527, 11, 20, 8, 116, 587, 9, 301, 361, 8, 184, 1, 28, 104, 17, 1021, 26, 68, 29136, 5291, 49, 358, 57, 2, 157, 344, 1, 64, 47, 319, 286, 343, 
47, 78, 146, 291, 33238, 7, 1, 1448, 9, 12033, 12, 55, 9, 4, 2, 7, 29, 534, 16, 7888, 17, 32, 438, 6, 124, 47, 18, 450, 271, 12, 36, 231, 417, 7, 2, 933, 16967, 134, 21, 8529, 6, 381, 50, 595, 14568, 1, 1, 42, 140, 39, 6, 967, 2, 1, 1, 1589, 7862, 4, 31, 3020, 63, 255, 25, 76, 70, 1817, 24, 3884, 16, 1, 1735, 897, 642, 6, 73, 1972, 143, 6, 340, 3, 131, 3, 7, 1033, 3, 844, 15, 471, 26, 22184, 63, 156, 3395, 26, 632, 1320, 1070, 24, 406, 1, 13013, 1722, 138, 31, 2, 6, 148, 176, 2149, 1342, 984, 129, 4367, 4, 2, 29, 6, 25, 1373, 171, 201, 56, 9559, 7226, 93, 1, 189, 104, 2, 544, 5, 4915, 48, 2, 20015, 145, 14, 722, 409, 156, 670, 3153, 62, 15, 12, 21, 44, 33, 1, 68, 501, 164, 1, 50, 1, 170, 117, 39, 111, 1, 2, 1020, 2, 125, 5, 94, 1300, 56, 2, 2207, 1569, 295, 62, 414, 5013, 83, 186, 32, 2, 202, 39, 1, 138238, 553, 1167, 650, 1, 84, 1568, 4598, 45, 3114, 2, 2195, 4726, 79, 15, 300, 763, 3593, 7, 375, 108, 5, 56, 3678, 6, 2276, 57, 4172, 991, 7, 48, 1, 1022, 17, 52, 1414, 45, 61, 87, 312, 516, 518, 5, 437, 677, 128, 22, 9, 97, 71, 7, 18, 344, 1, 85, 465, 33, 963, 7, 45, 1442, 3304, 2, 11, 66, 9141, 3475, 196, 91, 379, 17, 624, 52, 6851, 4189, 12, 25, 33, 42, 11, 1, 44, 3451, 48, 1, 1, 5912, 68, 224, 16, 3, 41, 355, 5, 34, 634, 1, 49, 138, 1626, 7, 1218, 511, 60, 203, 299, 125, 752, 20, 7, 43, 30117, 44, 22, 8, 2, 1, 513, 7238, 2, 58, 47, 136, 3, 19, 11, 16, 23, 7, 8, 155, 154, 65, 1456, 165, 82, 225, 64, 25, 39, 6, 17, 19, 1, 23, 195, 8, 3714, 22, 11, 684, 1164, 398, 2, 1031, 212, 6, 2, 238, 25, 1, 15, 307, 180, 4, 6, 332])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def beta_bernoulli(attempt, success, step='metropolis', niter=10000, mu_hypers=(1, 9), kappa_hypers=(1, 1)):\n", | |
"    \"\"\"Sample a hierarchical Beta-Binomial model with PyMC3.\n", | |
"\n", | |
"    Despite the function name, the likelihood is Binomial (not Bernoulli):\n", | |
"        mu ~ Beta(*mu_hypers)\n", | |
"        kappa ~ Gamma(*kappa_hypers)\n", | |
"        theta[i] ~ Beta(mu * kappa, (1 - mu) * kappa), one per entry of attempt\n", | |
"        success[i] ~ Binomial(n=attempt[i], p=theta[i])\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    attempt : sized sequence of trial counts; len(attempt) sets the shape of theta.\n", | |
"    success : observed success counts handed to the Binomial likelihood.\n", | |
"    step : 'metropolis', 'nuts' or 'nuts_advi' (compared case-insensitively).\n", | |
"        'metropolis' and 'nuts' start from pm.find_MAP(); 'nuts_advi' starts\n", | |
"        from the first value returned by pm.variational.advi(n=200000).\n", | |
"    niter : number of draws passed to pm.sample.\n", | |
"    mu_hypers : positional arguments forwarded to the Beta prior on mu.\n", | |
"    kappa_hypers : positional arguments forwarded to the Gamma prior on kappa.\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    Whatever pm.sample returns for the chosen step method.\n", | |
"\n", | |
"    Raises\n", | |
"    ------\n", | |
"    ValueError : if step is not one of the supported sampler names.\n", | |
"    \"\"\"\n", | |
"    sampler = step.lower()\n", | |
"    if sampler not in ('metropolis', 'nuts', 'nuts_advi'):\n", | |
"        # Reject unknown samplers before building the model, and name the bad value.\n", | |
"        # ValueError subclasses Exception, so callers catching Exception still work.\n", | |
"        raise ValueError(\"unknown step {!r}; expected 'metropolis', 'nuts' or 'nuts_advi'\".format(step))\n", | |
"\n", | |
"    # Step methods, find_MAP, advi and sample all pick the model up from this context.\n", | |
"    with pm.Model():\n", | |
"        mu = pm.Beta('mu', *mu_hypers)\n", | |
"        kappa = pm.Gamma('kappa', *kappa_hypers)\n", | |
"        theta = pm.Beta('theta', mu * kappa, (1 - mu) * kappa, shape=len(attempt))\n", | |
"        pm.Binomial('y', n=attempt, p=theta, observed=success)\n", | |
"\n", | |
"        if sampler == 'metropolis':\n", | |
"            print(\"Sample with metropolis\")\n", | |
"            step_method = pm.Metropolis()\n", | |
"            start = pm.find_MAP()\n", | |
"        elif sampler == 'nuts':\n", | |
"            print(\"Sample with NUTS\")\n", | |
"            step_method = pm.NUTS()\n", | |
"            start = pm.find_MAP()\n", | |
"        else:  # 'nuts_advi'\n", | |
"            print(\"Sample with NUTS with ADVI Initilization\")\n", | |
"            start, _, _ = pm.variational.advi(n=200000)\n", | |
"            step_method = pm.NUTS()\n", | |
"\n", | |
"        return pm.sample(niter, step=step_method, start=start, progressbar=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Applied logodds-transform to mu and added transformed mu_logodds_ to model.\n", | |
"Applied log-transform to kappa and added transformed kappa_log_ to model.\n", | |
"Applied logodds-transform to theta and added transformed theta_logodds_ to model.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Sample with metropolis\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 10000/10000 [00:07<00:00, 1401.86it/s]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<MultiTrace: 1 chains, 10000 iterations, 6 variables>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"beta_bernoulli(attempt, success, 'metropolis')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Applied logodds-transform to mu and added transformed mu_logodds_ to model.\n", | |
"Applied log-transform to kappa and added transformed kappa_log_ to model.\n", | |
"Applied logodds-transform to theta and added transformed theta_logodds_ to model.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Sample with NUTS\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 50/50 [00:27<00:00, 2.92it/s]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<MultiTrace: 1 chains, 50 iterations, 6 variables>" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"beta_bernoulli(attempt, success, 'nuts', niter=50)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Applied logodds-transform to mu and added transformed mu_logodds_ to model.\n", | |
"Applied log-transform to kappa and added transformed kappa_log_ to model.\n", | |
"Applied logodds-transform to theta and added transformed theta_logodds_ to model.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Sample with NUTS with ADVI Initilization\n", | |
"Iteration 0 [0%]: ELBO = -147658.03\n", | |
"Iteration 20000 [10%]: Average ELBO = -28632.13\n", | |
"Iteration 40000 [20%]: Average ELBO = -2059.87\n", | |
"Iteration 60000 [30%]: Average ELBO = -2114.85\n", | |
"Iteration 80000 [40%]: Average ELBO = -2188.04\n", | |
"Iteration 100000 [50%]: Average ELBO = -2257.17\n", | |
"Iteration 120000 [60%]: Average ELBO = -2322.03\n", | |
"Iteration 140000 [70%]: Average ELBO = -2384.65\n", | |
"Iteration 160000 [80%]: Average ELBO = -2448.75\n", | |
"Iteration 180000 [90%]: Average ELBO = -2497.9\n", | |
"Finished [100%]: Average ELBO = -2558.98\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 50/50 [00:27<00:00, 1.95it/s]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<MultiTrace: 1 chains, 50 iterations, 6 variables>" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"beta_bernoulli(attempt, success, 'nuts_advi', niter=50)" | |
] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python [conda env:ds-757-captcha-sampling]", | |
"language": "python", | |
"name": "conda-env-ds-757-captcha-sampling-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
For posterity, here's how you can speed up NUTS:
http://til.tdhopper.com/notes/speeding-up-pymc3-nuts-sampler