Created
November 26, 2018 16:28
-
-
Save sadatnfs/d5004f488ba00371333770059ab99776 to your computer and use it in GitHub Desktop.
Attempting to make PyMC4 work
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"from tensorflow_probability import edward2 as ed\n", | |
"from tensorflow_probability import distributions as tfd\n", | |
"import pymc4 as pm4" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## Create fake data\n", | |
"alpha_raw = 0.1\n", | |
"beta_raw = 0.5\n", | |
"sigma_raw = 0.25\n", | |
"N = 1500\n", | |
"x = np.random.normal(size = N) \n", | |
"y = alpha_raw + beta_raw * x + np.random.normal(scale = sigma_raw**0.5, size = N)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Reset graph\n", | |
"tf.reset_default_graph()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<pymc4.model.base.Model at 0x2abef427fc88>" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# PyMC4 model initialize\n", | |
"pymc4_ols = pm4.Model(X = x, Y = y)\n", | |
"pymc4_ols.observe(X = pymc4_ols.cfg.X, Y = pymc4_ols.cfg.Y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@pymc4_ols.define\n", | |
"def process(cfg):\n", | |
" alpha = ed.Normal(\n", | |
" loc=0., scale=5., name=\"alpha\")\n", | |
" beta = ed.Normal(\n", | |
" loc=0., scale=5., name = \"beta\")\n", | |
" sigma = tf.exp(ed.Gamma(1., 1., name = 'sigma'))\n", | |
" Yhat = alpha + beta*cfg.X\n", | |
" Y = ed.Normal(\n", | |
" loc=Yhat,\n", | |
" scale=sigma,\n", | |
" name=\"Y\")\n", | |
" return Y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'X': array([-0.83580625, -0.16027781, -0.20688124, ..., 0.01717773,\n", | |
" -1.28557357, 0.07603539]),\n", | |
" 'Y': array([-0.33343107, 0.17018252, -0.0079468 , ..., -0.5135459 ,\n", | |
" 0.10321138, 0.92192937])}" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pymc4_ols.observed" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"OrderedDict([('alpha',\n", | |
" VariableDescription(Dist=<class 'tensorflow_probability.python.distributions.normal.Normal'>, shape=TensorShape([]), rv=<ed.RandomVariable 'alpha/' shape=() dtype=float32>)),\n", | |
" ('beta',\n", | |
" VariableDescription(Dist=<class 'tensorflow_probability.python.distributions.normal.Normal'>, shape=TensorShape([]), rv=<ed.RandomVariable 'beta/' shape=() dtype=float32>)),\n", | |
" ('sigma',\n", | |
" VariableDescription(Dist=<class 'tensorflow_probability.python.distributions.gamma.Gamma'>, shape=TensorShape([]), rv=<ed.RandomVariable 'sigma/' shape=() dtype=float32>))])" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pymc4_ols.unobserved" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.6/site-packages/numpy-1.16.0.dev0+b47ed76-py3.6-linux-x86_64.egg/numpy/lib/type_check.py:549: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n", | |
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Acceptance rate: 0.0\n" | |
] | |
} | |
], | |
"source": [ | |
"pymc4_trace = pm4.sample(pymc4_ols, num_burnin_steps=1000, num_results=20)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'alpha': array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n", | |
" 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32),\n", | |
" 'beta': array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n", | |
" 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32),\n", | |
" 'sigma': array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n", | |
" 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32)}" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pymc4_trace" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"### However, the 8 schools example works off the box..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Data of the Eight Schools Model\n", | |
"J = 8\n", | |
"y = np.array([28., 8., -3., 7., -1., 1., 18., 12.])\n", | |
"sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.])\n", | |
"# tau = 25." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = pm4.Model(num_schools=J, y=y, sigma=sigma )\n", | |
"@model.define\n", | |
"def process(cfg):\n", | |
" mu = ed.Normal(loc=0., scale=5., name=\"mu\") # `mu` above\n", | |
" # Due to the lack of HalfCauchy distribution.\n", | |
" log_tau = ed.Normal(\n", | |
" loc=5., scale=1., name=\"log_tau\") # `log(tau)` above\n", | |
" theta_prime = ed.Normal(\n", | |
" loc=tf.zeros(cfg.num_schools),\n", | |
" scale=tf.ones(cfg.num_schools),\n", | |
" name=\"theta_prime\") # `theta_prime` above\n", | |
" theta = mu + tf.exp(\n", | |
" log_tau) * theta_prime # `theta` above\n", | |
" y = ed.Normal(\n", | |
" loc=theta,\n", | |
" scale=np.float32(cfg.sigma),\n", | |
" name=\"y\") # `y` above\n", | |
" \n", | |
" return y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<pymc4.model.base.Model at 0x2ac07cc1cdd8>" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.observe(y = model.cfg.y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.6/site-packages/numpy-1.16.0.dev0+b47ed76-py3.6-linux-x86_64.egg/numpy/lib/type_check.py:549: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n", | |
" 'a.item() instead', DeprecationWarning, stacklevel=1)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Acceptance rate: 0.6102\n" | |
] | |
} | |
], | |
"source": [ | |
"trace = pm4.sample(model, num_burnin_steps=1000, num_results=5000)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'mu': array([ 1.2623473, 1.2623473, 1.2623473, ..., 10.174107 , 10.101565 ,\n", | |
" 11.02169 ], dtype=float32),\n", | |
" 'log_tau': array([2.798712 , 2.798712 , 2.798712 , ..., 2.237354 , 2.206139 ,\n", | |
" 1.4436612], dtype=float32),\n", | |
" 'theta_prime': array([[ 1.4495487 , 0.13190484, -0.55798876, ..., -0.39001518,\n", | |
" 0.49999905, -0.14500381],\n", | |
" [ 1.4495487 , 0.13190484, -0.55798876, ..., -0.39001518,\n", | |
" 0.49999905, -0.14500381],\n", | |
" [ 1.4495487 , 0.13190484, -0.55798876, ..., -0.39001518,\n", | |
" 0.49999905, -0.14500381],\n", | |
" ...,\n", | |
" [ 2.3089426 , 0.6587325 , 0.04801591, ..., -0.14484467,\n", | |
" 1.0311049 , -0.980293 ],\n", | |
" [ 0.772265 , -1.3196737 , -2.4158447 , ..., 0.49178684,\n", | |
" -0.6607568 , 0.24697188],\n", | |
" [ 0.01715243, -1.6104103 , 0.7611127 , ..., -0.10088314,\n", | |
" 1.5387309 , 0.6444084 ]], dtype=float32)}" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"trace" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment