Fitting a simple random-intercept model on GDP data using TFP (following the radon example in the TFP documentation)
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.contrib.distributions import MultivariateNormalTriL\n",
    "import tensorflow_probability as tfp\n",
    "import warnings\n",
    "import statsmodels.api as sm\n",
    "import statsmodels.formula.api as smf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv('/home/j/Project/IRH/Forecasting/gdp/data/RT_2018_GDP_use.csv')\n",
    "\n",
    "## Prep data\n",
    "data = data[['iso3', 'year', 'ln_gdppc', 'ln_TFR', 'ln_pop']]\n",
    "data['intercept'] = 1.\n",
    "data = data.dropna()\n",
    "\n",
    "# Remap categories to start from 0 and end at max(category).\n",
    "data['iso3'] = data['iso3'].astype('category').cat.codes\n",
    "\n",
    "## Number of random effects (one intercept per iso3 code)\n",
    "n_res = max(data.iso3) + 1\n"
   ]
  },
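  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model being fit (a sketch inferred from the code below, not stated in the original notebook) is a random-intercept regression on log GDP per capita, with one intercept per `iso3` country code:\n",
    "\n",
    "$$\\ln(\\mathrm{gdppc})_{ct} = \\beta_0 + \\beta_1 \\, \\ln(\\mathrm{TFR})_{ct} + \\beta_2 \\, \\ln(\\mathrm{pop})_{ct} + \\alpha_c + \\epsilon_{ct}$$\n",
    "\n",
    "with $\\alpha_c \\sim \\mathcal{N}(0, \\sigma_\\alpha^2)$ and $\\epsilon_{ct} \\sim \\mathcal{N}(0, \\sigma_y^2)$."
   ]
  },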
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inv_scale_transform = lambda y: np.log(y)  # Not using TF here.\n",
    "fwd_scale_transform = tf.exp"
   ]
  },
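  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both scale parameters are stored on the log scale so gradient-based optimization is unconstrained; `fwd_scale_transform = tf.exp` maps the raw variable back to a positive scale. A quick illustrative check (not in the original run):\n",
    "\n",
    "```python\n",
    "raw = inv_scale_transform(1.)  # 0.0, the unconstrained representation\n",
    "# fwd_scale_transform(raw) is a TF op; it evaluates back to 1.0\n",
    "```"
   ]
  },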
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _make_weights_prior(num_counties, dtype):\n",
    "  \"\"\"Returns a `num_counties` batch of univariate Normals as one joint.\"\"\"\n",
    "  raw_prior_scale = tf.get_variable(\n",
    "      name='raw_prior_scale',\n",
    "      initializer=np.array(inv_scale_transform(1.), dtype=dtype))\n",
    "  return tfp.distributions.Independent(\n",
    "      tfp.distributions.Normal(\n",
    "          # Use the `num_counties` argument, not the global `n_res`.\n",
    "          loc=tf.zeros(num_counties, dtype=dtype),\n",
    "          scale=fwd_scale_transform(raw_prior_scale)),\n",
    "      reinterpreted_batch_ndims=1)\n"
   ]
  },
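  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tfp.distributions.Independent` with `reinterpreted_batch_ndims=1` turns the batch of `num_counties` univariate Normals into a single joint distribution over the weight vector, so `log_prob` sums over the last axis and returns a scalar. A minimal illustration (not from the original notebook):\n",
    "\n",
    "```python\n",
    "batch = tfp.distributions.Normal(loc=tf.zeros(3), scale=1.)\n",
    "batch.log_prob(tf.zeros(3))   # shape [3]: one log-density per component\n",
    "joint = tfp.distributions.Independent(batch, reinterpreted_batch_ndims=1)\n",
    "joint.log_prob(tf.zeros(3))   # shape []: summed over the components\n",
    "```"
   ]
  },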
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _make_log_Y_likelihood(random_effect_weights,\n",
    "                           ln_TFR,\n",
    "                           ln_pop,\n",
    "                           iso3,\n",
    "                           iso3_int,\n",
    "                           init_log_Y_stddev):\n",
    "  # Note: `dtype` is read from the enclosing (notebook-level) scope.\n",
    "  raw_likelihood_scale = tf.get_variable(\n",
    "      name='raw_likelihood_scale',\n",
    "      initializer=np.array(\n",
    "          inv_scale_transform(init_log_Y_stddev), dtype=dtype))\n",
    "  fixed_effect_weights = tf.get_variable(\n",
    "      name='fixed_effect_weights',\n",
    "      # initializer=np.array([0., 1., 2.], dtype=dtype)\n",
    "      shape=[3])\n",
    "  fixed_effects = (fixed_effect_weights[0]\n",
    "                   + fixed_effect_weights[1] * ln_TFR\n",
    "                   + fixed_effect_weights[2] * ln_pop)\n",
    "  # One random intercept per country, indexed out to each observation.\n",
    "  random_effects = tf.gather(\n",
    "      random_effect_weights * iso3_int,\n",
    "      indices=tf.to_int32(iso3),\n",
    "      axis=-1)\n",
    "  linear_predictor = fixed_effects + random_effects\n",
    "  return tfp.distributions.Normal(\n",
    "      loc=linear_predictor,\n",
    "      scale=fwd_scale_transform(raw_likelihood_scale))"
   ]
  },
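  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tf.gather` broadcasts the per-country intercepts out to one value per observation by indexing with the integer `iso3` codes. A hypothetical two-country example:\n",
    "\n",
    "```python\n",
    "# weights = [a, b]; observations with iso3 codes [0, 0, 1]\n",
    "# tf.gather([a, b], indices=[0, 0, 1], axis=-1)  ->  [a, a, b]\n",
    "```"
   ]
  },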
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [],
   "source": [
    "def joint_log_prob(random_effect_weights,\n",
    "                   ln_gdppc,\n",
    "                   ln_TFR,\n",
    "                   ln_pop,\n",
    "                   iso3,\n",
    "                   iso3_int,\n",
    "                   dtype):\n",
    "  # One random effect per unique iso3 code (equals the global `n_res`).\n",
    "  num_counties = len(iso3_int)\n",
    "  rv_weights = make_weights_prior(num_counties, dtype)\n",
    "  rv_Y = make_log_Y_likelihood(\n",
    "      random_effect_weights,\n",
    "      ln_TFR,\n",
    "      ln_pop,\n",
    "      iso3,\n",
    "      iso3_int,\n",
    "      init_log_Y_stddev=1.)\n",
    "  return (rv_weights.log_prob(random_effect_weights)\n",
    "          + tf.reduce_sum(rv_Y.log_prob(ln_gdppc), axis=-1))"
   ]
  },
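  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The joint log probability is the prior log-density of the random intercepts plus the summed per-observation likelihood:\n",
    "\n",
    "$$\\log p(\\alpha, y) = \\log p(\\alpha) + \\sum_i \\log p(y_i \\mid \\alpha, \\beta)$$"
   ]
  },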
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify unnormalized posterior.\n",
    "def unnormalized_posterior_log_prob(random_effect_weights):\n",
    "  return joint_log_prob(\n",
    "      random_effect_weights=random_effect_weights,\n",
    "      ln_gdppc=dtype(data.ln_gdppc.values),\n",
    "      ln_TFR=dtype(data.ln_TFR.values),\n",
    "      ln_pop=dtype(data.ln_pop.values),\n",
    "      iso3=np.int32(data.iso3.values),\n",
    "      iso3_int=iso3_int,\n",
    "      dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Graph resetter\n",
    "with warnings.catch_warnings():\n",
    "  warnings.simplefilter('ignore')\n",
    "  tf.reset_default_graph()\n",
    "  try:\n",
    "    sess.close()\n",
    "  except NameError:  # No session exists yet.\n",
    "    pass\n",
    "  sess = tf.InteractiveSession()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Create TF variables via templates, so they are shared across calls\n",
    "make_weights_prior = tf.make_template(\n",
    "    name_='make_weights_prior', func_=_make_weights_prior)\n",
    "\n",
    "make_log_Y_likelihood = tf.make_template(\n",
    "    name_='make_log_Y_likelihood', func_=_make_log_Y_likelihood)\n",
    "\n",
    "dtype = np.float32\n",
    "# One intercept entry (a 1.) per unique iso3 code.\n",
    "iso3_int = data[\n",
    "    ['iso3', 'intercept']].drop_duplicates().values[:, 1]\n",
    "iso3_int = iso3_int.astype(dtype)"
   ]
  },
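  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tf.make_template` wraps each function so that its first call creates the `tf.get_variable` variables and every later call reuses them; this is what lets the sampler graph and the diagnostic handles below share `raw_prior_scale`, `raw_likelihood_scale`, and `fixed_effect_weights`. An illustrative sketch (hypothetical names):\n",
    "\n",
    "```python\n",
    "tmpl = tf.make_template('scope', lambda: tf.get_variable('v', shape=[]))\n",
    "v1, v2 = tmpl(), tmpl()   # v1 is v2: the second call reuses the variable\n",
    "```"
   ]
  },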
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the E-step: draw HMC samples of the random effects.\n",
    "\n",
    "# step_size = tf.get_variable(\n",
    "#     'step_size',\n",
    "#     initializer=np.array(0.2, dtype=dtype),\n",
    "#     trainable=False)\n",
    "\n",
    "hmc = tfp.mcmc.HamiltonianMonteCarlo(\n",
    "    target_log_prob_fn=unnormalized_posterior_log_prob,\n",
    "    num_leapfrog_steps=2,\n",
    "    step_size=0.015,\n",
    "    # step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(),\n",
    "    state_gradients_are_stopped=True)\n",
    "\n",
    "init_random_weights = tf.placeholder(dtype, shape=[len(iso3_int)])\n",
    "\n",
    "posterior_random_weights, kernel_results = tfp.mcmc.sample_chain(\n",
    "    num_results=3,\n",
    "    num_burnin_steps=0,\n",
    "    num_steps_between_results=0,\n",
    "    current_state=init_random_weights,\n",
    "    kernel=hmc)"
   ]
  },
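  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each E-step draws 3 HMC samples of the random intercepts, warm-started from the previous iteration's last draw (fed through `init_random_weights`). `state_gradients_are_stopped=True` stops gradients from flowing through the sampled state, so the M-step below differentiates only with respect to the model variables (the two scales and the fixed-effect weights)."
   ]
  },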
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the M-step: optimize the model variables given the samples.\n",
    "loss = -tf.reduce_mean(kernel_results.accepted_results.target_log_prob)\n",
    "global_step = tf.train.get_or_create_global_step()\n",
    "learning_rate = tf.train.exponential_decay(\n",
    "    learning_rate=0.1,\n",
    "    global_step=global_step,\n",
    "    decay_steps=250,\n",
    "    decay_rate=0.99)\n",
    "\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
    "train_op = optimizer.minimize(loss, global_step=global_step)"
   ]
  },
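  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The M-step loss is the negative mean of the target log probability over the accepted samples, i.e. a 3-sample Monte Carlo estimate of $-\\mathbb{E}\\big[\\log p(\\alpha, y)\\big]$ under the current posterior draws, minimized with Adam under a learning rate that decays by 0.99 every 250 steps."
   ]
  },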
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize all variables.\n",
    "init_op = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab variable handles for diagnostic purposes.\n",
    "with tf.variable_scope('make_weights_prior', reuse=True):\n",
    "  prior_scale = fwd_scale_transform(\n",
    "      tf.get_variable(name='raw_prior_scale', dtype=dtype))\n",
    "\n",
    "with tf.variable_scope('make_log_Y_likelihood', reuse=True):\n",
    "  likelihood_scale = fwd_scale_transform(\n",
    "      tf.get_variable(name='raw_likelihood_scale', dtype=dtype))\n",
    "  fixed_effect_weights = tf.get_variable(\n",
    "      name='fixed_effect_weights', dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_op.run()\n",
    "w_ = np.zeros([len(iso3_int)], dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "global_step: 1 loss: 402393.500 acceptance:1.0000 prior_scale:1.0000 likelihood_scale:1.0000 fixed_effect_weights:[0.86661494 0.40798506 0.01009426]\n",
      "global_step: 101 loss: 16923.705 acceptance:1.0000 prior_scale:1.7138 likelihood_scale:1.5678 fixed_effect_weights:[1.3247846 0.44955173 0.38326803]\n",
      "global_step: 201 loss: 5929.962 acceptance:0.9967 prior_scale:1.8976 likelihood_scale:0.4700 fixed_effect_weights:[ 1.421003 -0.2120042 0.427505 ]\n",
      "global_step: 301 loss: 5078.015 acceptance:0.9756 prior_scale:1.8770 likelihood_scale:0.3634 fixed_effect_weights:[ 1.6500283 -0.42461014 0.42339408]\n",
      "global_step: 401 loss: 4999.351 acceptance:0.9626 prior_scale:1.9097 likelihood_scale:0.3613 fixed_effect_weights:[ 1.9427483 -0.4560817 0.39843553]\n",
      "global_step: 501 loss: 4922.601 acceptance:0.9501 prior_scale:1.9320 likelihood_scale:0.3590 fixed_effect_weights:[ 2.2689273 -0.48743185 0.3714246 ]\n",
      "global_step: 601 loss: 4859.388 acceptance:0.9412 prior_scale:1.9476 likelihood_scale:0.3563 fixed_effect_weights:[ 2.6151083 -0.5074523 0.34457052]\n",
      "global_step: 701 loss: 4792.275 acceptance:0.9387 prior_scale:1.9343 likelihood_scale:0.3542 fixed_effect_weights:[ 2.9732335 -0.5322577 0.32052314]\n",
      "global_step: 801 loss: 4728.039 acceptance:0.9363 prior_scale:1.9771 likelihood_scale:0.3525 fixed_effect_weights:[ 3.3495574 -0.54722965 0.2905442 ]\n",
      "global_step: 901 loss: 4844.631 acceptance:0.9397 prior_scale:1.9502 likelihood_scale:0.3528 fixed_effect_weights:[ 3.7223132 -0.58536816 0.26467618]\n",
      "global_step:1001 loss: 4660.056 acceptance:0.9424 prior_scale:1.9584 likelihood_scale:0.3491 fixed_effect_weights:[ 4.0668674 -0.60071796 0.24298063]\n",
      "global_step:1101 loss: 4619.966 acceptance:0.9437 prior_scale:1.9542 likelihood_scale:0.3472 fixed_effect_weights:[ 4.3808465 -0.6201686 0.22685133]\n",
      "global_step:1201 loss: 4559.861 acceptance:0.9473 prior_scale:1.8865 likelihood_scale:0.3466 fixed_effect_weights:[ 4.7036843 -0.62965864 0.20515324]\n",
      "global_step:1301 loss: 4489.408 acceptance:0.9485 prior_scale:1.8913 likelihood_scale:0.3515 fixed_effect_weights:[ 4.9884224 -0.66703933 0.19129947]\n",
      "global_step:1401 loss: 4544.491 acceptance:0.9507 prior_scale:1.8677 likelihood_scale:0.3456 fixed_effect_weights:[ 5.262151 -0.65960604 0.17540511]\n",
      "global_step:1500 loss: 4420.034 acceptance:0.9536 prior_scale:1.8611 likelihood_scale:0.3428 fixed_effect_weights:[ 5.5442433 -0.68676883 0.1565091 ]\n",
      "CPU times: user 14min 45s, sys: 38min 23s, total: 53min 8s\n",
      "Wall time: 1min 1s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "maxiter = 1500\n",
    "num_accepted = 0\n",
    "num_drawn = 0\n",
    "for i in range(maxiter):\n",
    "  [\n",
    "      _,\n",
    "      global_step_,\n",
    "      loss_,\n",
    "      posterior_random_weights_,\n",
    "      kernel_results_,\n",
    "      # step_size_,\n",
    "      prior_scale_,\n",
    "      likelihood_scale_,\n",
    "      fixed_effect_weights_,\n",
    "  ] = sess.run([\n",
    "      train_op,\n",
    "      global_step,\n",
    "      loss,\n",
    "      posterior_random_weights,\n",
    "      kernel_results,\n",
    "      # step_size,\n",
    "      prior_scale,\n",
    "      likelihood_scale,\n",
    "      fixed_effect_weights,\n",
    "  ], feed_dict={init_random_weights: w_})\n",
    "  # Warm-start the next E-step from the last HMC draw.\n",
    "  w_ = posterior_random_weights_[-1, :]\n",
    "  num_accepted += kernel_results_.is_accepted.sum()\n",
    "  num_drawn += kernel_results_.is_accepted.size\n",
    "  acceptance_rate = num_accepted / num_drawn\n",
    "  if i % 100 == 0 or i == maxiter - 1:\n",
    "    print('global_step:{:>4} loss:{: 9.3f} acceptance:{:.4f} '\n",
    "          'prior_scale:{:.4f} likelihood_scale:{:.4f} '\n",
    "          'fixed_effect_weights:{}'.format(\n",
    "              global_step_, loss_.mean(), acceptance_rate,\n",
    "              prior_scale_, likelihood_scale_, fixed_effect_weights_))"
   ]
  },
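  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reading the trace above: over 1500 EM iterations the loss falls from roughly 402,000 to about 4,420 while the HMC acceptance rate stays near 0.95; the likelihood scale settles around 0.34 and the prior (random-effect) scale around 1.86."
   ]
  },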
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 15min 9s, sys: 38min 11s, total: 53min 20s\n",
      "Wall time: 1min 1s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "posterior_random_weights_final, kernel_results_final = tfp.mcmc.sample_chain(\n",
    "    num_results=int(1e3),\n",
    "    num_burnin_steps=int(1e3),\n",
    "    current_state=init_random_weights,\n",
    "    kernel=tfp.mcmc.HamiltonianMonteCarlo(\n",
    "        target_log_prob_fn=unnormalized_posterior_log_prob,\n",
    "        num_leapfrog_steps=2,\n",
    "        step_size=0.015))\n",
    "\n",
    "[\n",
    "    posterior_random_weights_final_,\n",
    "    kernel_results_final_,\n",
    "] = sess.run([\n",
    "    posterior_random_weights_final,\n",
    "    kernel_results_final,\n",
    "], feed_dict={init_random_weights: w_})"
   ]
  },
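  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`posterior_random_weights_final_` holds one row of country intercepts per posterior draw. A hypothetical post-processing step (not run in the original notebook) to summarize them:\n",
    "\n",
    "```python\n",
    "re_mean = posterior_random_weights_final_.mean(axis=0)  # shape [n_res]\n",
    "re_sd = posterior_random_weights_final_.std(axis=0)\n",
    "```"
   ]
  },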
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prior_scale: 1.8611326\n",
      "likelihood_scale: 0.34281462\n",
      "fixed_effect_weights: [ 5.5442433 -0.68676883 0.1565091 ]\n",
      "acceptance rate final: 0.9015333333333333\n"
     ]
    }
   ],
   "source": [
    "print('prior_scale: ', prior_scale_)\n",
    "print('likelihood_scale: ', likelihood_scale_)\n",
    "print('fixed_effect_weights: ', fixed_effect_weights_)\n",
    "print('acceptance rate final: ', kernel_results_final_.is_accepted.mean())"
   ]
  },
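  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`statsmodels` is imported but unused above; a hypothetical frequentist cross-check of the same random-intercept model could look like:\n",
    "\n",
    "```python\n",
    "fit = smf.mixedlm('ln_gdppc ~ ln_TFR + ln_pop', data, groups=data['iso3']).fit()\n",
    "print(fit.summary())\n",
    "```"
   ]
  },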
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MetropolisHastingsKernelResults(accepted_results=UncalibratedHamiltonianMonteCarloKernelResults(log_acceptance_correction=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_1/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, target_log_prob=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_2/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, grads_target_log_prob=[<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_3/TensorArrayGatherV3:0' shape=(15000, 195) dtype=float32>]), is_accepted=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_4/TensorArrayGatherV3:0' shape=(15000,) dtype=bool>, log_accept_ratio=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_5/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, proposed_state=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_6/TensorArrayGatherV3:0' shape=(15000, 195) dtype=float32>, proposed_results=UncalibratedHamiltonianMonteCarloKernelResults(log_acceptance_correction=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_7/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, target_log_prob=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_8/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, grads_target_log_prob=[<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_9/TensorArrayGatherV3:0' shape=(15000, 195) dtype=float32>]), extra=[])"
      ]
     },
     "execution_count": 143,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "kernel_results_final"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |