Fitting a simple random intercept model on GDP data using TFP (following the radon example on TFP page)
{
"cells": [
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
" import numpy as np\n",
" import pandas as pd\n",
" import tensorflow as tf\n",
" import matplotlib.pyplot as plt\n",
" from tensorflow.contrib.distributions import MultivariateNormalTriL\n",
" import tensorflow_probability as tfp\n",
" import warnings \n",
" import statsmodels.api as sm\n",
" import statsmodels.formula.api as smf"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_csv('/home/j/Project/IRH/Forecasting/gdp/data/RT_2018_GDP_use.csv')\n",
"\n",
"## Prep data\n",
"data = data[['iso3', 'year', 'ln_gdppc', 'ln_TFR', 'ln_pop']]\n",
"data['intercept'] = 1.\n",
"data = data.dropna()\n",
"\n",
"# Remap categories to start from 0 and end at max(category).\n",
"data['iso3'] = data['iso3'].astype('category').cat.codes\n",
"\n",
"## Number of REs\n",
"n_res = max(data.iso3) + 1\n"
]
},
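{
"cell_type": "markdown",
"metadata": {},
"source": [
"The remap above produces consecutive integer codes `0 .. n_res - 1`, which is what the `tf.gather` lookup in the likelihood needs. A quick illustrative sanity check:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check (assumes `data` is the prepped frame above):\n",
"# category codes must be contiguous for tf.gather indexing.\n",
"assert sorted(data.iso3.unique()) == list(range(n_res))\n",
"data.head()"
]
},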
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inv_scale_transform = lambda y: np.log(y) # Not using TF here.\n",
"fwd_scale_transform = tf.exp"
]
},
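{
"cell_type": "markdown",
"metadata": {},
"source": [
"The positive scale parameters ($\\sigma$ for the likelihood, $\\tau$ for the prior) are stored on the unconstrained log scale and mapped back through `tf.exp`, so the M-step optimizer can move freely over all of $\\mathbb{R}$ without ever proposing a negative scale."
]
},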
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"def _make_weights_prior(num_counties, dtype):\n",
" \"\"\"Returns a `len(iso3)` batch of univariate Normal.\"\"\"\n",
" raw_prior_scale = tf.get_variable(\n",
" name='raw_prior_scale',\n",
" initializer=np.array(inv_scale_transform(1.), dtype=dtype))\n",
" return tfp.distributions.Independent(\n",
" tfp.distributions.Normal(\n",
" loc=tf.zeros(n_res, dtype=dtype),\n",
" scale=fwd_scale_transform(raw_prior_scale)),\n",
" reinterpreted_batch_ndims=1)\n"
]
},
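{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tfp.distributions.Independent` with `reinterpreted_batch_ndims=1` reinterprets the batch of univariate Normals as one joint distribution over all intercepts, so `log_prob` returns a single scalar (the sum of the component log-probs). A minimal standalone sketch of that behavior:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: Independent folds the batch dimension into the event,\n",
"# so the 3 Normals below act as one 3-dimensional distribution.\n",
"d = tfp.distributions.Independent(\n",
"    tfp.distributions.Normal(loc=tf.zeros(3), scale=1.),\n",
"    reinterpreted_batch_ndims=1)\n",
"print(d.event_shape)  # (3,)"
]
},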
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"def _make_log_Y_likelihood(random_effect_weights, \n",
" ln_TFR,\n",
" ln_pop,\n",
" iso3,\n",
" iso3_int, \n",
" init_log_Y_stddev):\n",
" raw_likelihood_scale = tf.get_variable(\n",
" name='raw_likelihood_scale',\n",
" initializer=np.array(\n",
" inv_scale_transform(init_log_Y_stddev), dtype=dtype))\n",
" fixed_effect_weights = tf.get_variable(\n",
" name='fixed_effect_weights', \n",
"# initializer=np.array([0., 1., 2.], dtype=dtype)\n",
" shape = [3],\n",
" )\n",
" fixed_effects = fixed_effect_weights[0] + \\\n",
" fixed_effect_weights[1] * ln_TFR + \\\n",
" fixed_effect_weights[2] * ln_pop\n",
" random_effects = tf.gather( \\\n",
" random_effect_weights * iso3_int,\n",
" indices=tf.to_int32(iso3),\n",
" axis=-1)\n",
" linear_predictor = fixed_effects + random_effects\n",
" return tfp.distributions.Normal(\n",
" loc=linear_predictor, \\\n",
" scale=fwd_scale_transform(raw_likelihood_scale))"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"def joint_log_prob(random_effect_weights, \n",
" ln_gdppc, \n",
" ln_TFR,\n",
" ln_pop,\n",
" iso3,\n",
" iso3_int, \n",
" dtype):\n",
" num_counties = len(iso3_int)\n",
" rv_weights = make_weights_prior(n_res, dtype)\n",
" rv_Y = make_log_Y_likelihood(\n",
" random_effect_weights,\n",
" ln_TFR,\n",
" ln_pop,\n",
" iso3,\n",
" iso3_int, \n",
" init_log_Y_stddev=1.)\n",
" return (rv_weights.log_prob(random_effect_weights)\n",
" + tf.reduce_sum(rv_Y.log_prob(ln_gdppc), axis=-1))"
]
},
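{
"cell_type": "markdown",
"metadata": {},
"source": [
"The joint log-density factorizes as prior plus likelihood,\n",
"\n",
"$$\\log p(\\alpha, y) = \\log p(\\alpha) + \\sum_{i,t} \\log p(y_{it} \\mid \\alpha_i, x_{it}),$$\n",
"\n",
"which is exactly the sum returned above: `rv_weights.log_prob` is the prior term and the `tf.reduce_sum` over `rv_Y.log_prob(ln_gdppc)` is the per-observation likelihood term."
]
},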
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"# Specify unnormalized posterior.\n",
"def unnormalized_posterior_log_prob(random_effect_weights):\n",
" return joint_log_prob(\n",
" random_effect_weights = random_effect_weights, \n",
" ln_gdppc = dtype(data.ln_gdppc.values), \n",
" ln_TFR = dtype(data.ln_TFR.values), \n",
" ln_pop = dtype(data.ln_pop.values), \n",
" iso3 = np.int32(data.iso3.values),\n",
" iso3_int = iso3_int, \n",
" dtype = dtype)"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
"## Graph resettor\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter('ignore')\n",
" tf.reset_default_graph()\n",
" try:\n",
" sess.close()\n",
" except:\n",
" pass\n",
" sess = tf.InteractiveSession()"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
"## Create TF vars\n",
"make_weights_prior = tf.make_template(\n",
" name_='make_weights_prior', func_=_make_weights_prior)\n",
"\n",
"make_log_Y_likelihood = tf.make_template(\n",
" name_='make_log_Y_likelihood', func_=_make_log_Y_likelihood)\n",
"\n",
"dtype = np.float32\n",
"iso3_int = data[\n",
" ['iso3', 'intercept']].drop_duplicates().values[:, 1]\n",
"iso3_int = iso3_int.astype(dtype)"
]
},
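{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tf.make_template` wraps each function so that its `tf.get_variable` calls create variables on the first invocation and transparently reuse them afterwards, under the scopes `make_weights_prior` and `make_log_Y_likelihood`; the diagnostics cell further below re-opens those same scopes with `reuse=True` to grab handles to the variables."
]
},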
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"# Set-up E-step.\n",
"\n",
"# step_size = tf.get_variable(\n",
"# 'step_size',\n",
"# initializer=np.array(0.2, dtype=dtype),\n",
"# trainable=False)\n",
"\n",
"hmc = tfp.mcmc.HamiltonianMonteCarlo(\n",
" target_log_prob_fn=unnormalized_posterior_log_prob,\n",
" num_leapfrog_steps=2,\n",
" step_size=0.015,\n",
"# step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(),\n",
" state_gradients_are_stopped=True)\n",
"\n",
"init_random_weights = tf.placeholder(dtype, shape=[len(iso3_int)])\n",
"\n",
"posterior_random_weights, kernel_results = tfp.mcmc.sample_chain(\n",
" num_results=3,\n",
" num_burnin_steps=0,\n",
" num_steps_between_results=0,\n",
" current_state=init_random_weights,\n",
" kernel=hmc)"
]
},
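{
"cell_type": "markdown",
"metadata": {},
"source": [
"E-step: with the trainable variables (fixed effects and scales) held at their current values, a short HMC chain (3 draws, no burn-in) samples the random intercepts from their conditional posterior. `state_gradients_are_stopped=True` wraps the sampled state in `tf.stop_gradient`, so the M-step gradient flows only into the trainable variables, not back through the chain."
]
},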
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
"# Set-up M-step.\n",
"loss = -tf.reduce_mean(kernel_results.accepted_results.target_log_prob)\n",
"global_step = tf.train.get_or_create_global_step()\n",
"learning_rate = tf.train.exponential_decay(\n",
" learning_rate=0.1,\n",
" global_step=global_step,\n",
" decay_steps=250,\n",
" decay_rate=0.99)\n",
"\n",
"optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
"train_op = optimizer.minimize(loss, global_step=global_step)"
]
},
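{
"cell_type": "markdown",
"metadata": {},
"source": [
"M-step: the loss is the negative mean of the target log-probability over the accepted samples, i.e. a Monte Carlo estimate of the negative expected joint log-density; each Adam step moves the fixed effects and scales uphill. Alternating this with the E-step above is a stochastic (Monte Carlo) EM algorithm."
]
},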
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"# Initialize all variables.\n",
"init_op = tf.global_variables_initializer()"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [
"# Grab variable handles for diagnostic purposes.\n",
"with tf.variable_scope('make_weights_prior', reuse=True):\n",
" prior_scale = fwd_scale_transform(tf.get_variable(name='raw_prior_scale', dtype=dtype))\n",
"\n",
"with tf.variable_scope('make_log_Y_likelihood', reuse=True):\n",
" likelihood_scale = fwd_scale_transform(tf.get_variable(name='raw_likelihood_scale', dtype=dtype))\n",
" fixed_effect_weights = tf.get_variable(name='fixed_effect_weights', dtype=dtype)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [],
"source": [
"init_op.run()\n",
"w_ = np.zeros([len(iso3_int)], dtype=dtype)"
]
},
{
"cell_type": "code",
"execution_count": 137,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"global_step: 1 loss: 402393.500 acceptance:1.0000 prior_scale:1.0000 likelihood_scale:1.0000 fixed_effect_weights:[0.86661494 0.40798506 0.01009426]\n",
"global_step: 101 loss: 16923.705 acceptance:1.0000 prior_scale:1.7138 likelihood_scale:1.5678 fixed_effect_weights:[1.3247846 0.44955173 0.38326803]\n",
"global_step: 201 loss: 5929.962 acceptance:0.9967 prior_scale:1.8976 likelihood_scale:0.4700 fixed_effect_weights:[ 1.421003 -0.2120042 0.427505 ]\n",
"global_step: 301 loss: 5078.015 acceptance:0.9756 prior_scale:1.8770 likelihood_scale:0.3634 fixed_effect_weights:[ 1.6500283 -0.42461014 0.42339408]\n",
"global_step: 401 loss: 4999.351 acceptance:0.9626 prior_scale:1.9097 likelihood_scale:0.3613 fixed_effect_weights:[ 1.9427483 -0.4560817 0.39843553]\n",
"global_step: 501 loss: 4922.601 acceptance:0.9501 prior_scale:1.9320 likelihood_scale:0.3590 fixed_effect_weights:[ 2.2689273 -0.48743185 0.3714246 ]\n",
"global_step: 601 loss: 4859.388 acceptance:0.9412 prior_scale:1.9476 likelihood_scale:0.3563 fixed_effect_weights:[ 2.6151083 -0.5074523 0.34457052]\n",
"global_step: 701 loss: 4792.275 acceptance:0.9387 prior_scale:1.9343 likelihood_scale:0.3542 fixed_effect_weights:[ 2.9732335 -0.5322577 0.32052314]\n",
"global_step: 801 loss: 4728.039 acceptance:0.9363 prior_scale:1.9771 likelihood_scale:0.3525 fixed_effect_weights:[ 3.3495574 -0.54722965 0.2905442 ]\n",
"global_step: 901 loss: 4844.631 acceptance:0.9397 prior_scale:1.9502 likelihood_scale:0.3528 fixed_effect_weights:[ 3.7223132 -0.58536816 0.26467618]\n",
"global_step:1001 loss: 4660.056 acceptance:0.9424 prior_scale:1.9584 likelihood_scale:0.3491 fixed_effect_weights:[ 4.0668674 -0.60071796 0.24298063]\n",
"global_step:1101 loss: 4619.966 acceptance:0.9437 prior_scale:1.9542 likelihood_scale:0.3472 fixed_effect_weights:[ 4.3808465 -0.6201686 0.22685133]\n",
"global_step:1201 loss: 4559.861 acceptance:0.9473 prior_scale:1.8865 likelihood_scale:0.3466 fixed_effect_weights:[ 4.7036843 -0.62965864 0.20515324]\n",
"global_step:1301 loss: 4489.408 acceptance:0.9485 prior_scale:1.8913 likelihood_scale:0.3515 fixed_effect_weights:[ 4.9884224 -0.66703933 0.19129947]\n",
"global_step:1401 loss: 4544.491 acceptance:0.9507 prior_scale:1.8677 likelihood_scale:0.3456 fixed_effect_weights:[ 5.262151 -0.65960604 0.17540511]\n",
"global_step:1500 loss: 4420.034 acceptance:0.9536 prior_scale:1.8611 likelihood_scale:0.3428 fixed_effect_weights:[ 5.5442433 -0.68676883 0.1565091 ]\n",
"CPU times: user 14min 45s, sys: 38min 23s, total: 53min 8s\n",
"Wall time: 1min 1s\n"
]
}
],
"source": [
"%%time\n",
"maxiter = int(1500)\n",
"num_accepted = 0\n",
"num_drawn = 0\n",
"for i in range(maxiter):\n",
" [\n",
" _,\n",
" global_step_,\n",
" loss_,\n",
" posterior_random_weights_,\n",
" kernel_results_,\n",
"# step_size_,\n",
" prior_scale_,\n",
" likelihood_scale_,\n",
" fixed_effect_weights_,\n",
" ] = sess.run([\n",
" train_op,\n",
" global_step,\n",
" loss,\n",
" posterior_random_weights,\n",
" kernel_results,\n",
"# step_size,\n",
" prior_scale,\n",
" likelihood_scale,\n",
" fixed_effect_weights,\n",
" ], feed_dict={init_random_weights: w_})\n",
" w_ = posterior_random_weights_[-1, :]\n",
" num_accepted += kernel_results_.is_accepted.sum()\n",
" num_drawn += kernel_results_.is_accepted.size\n",
" acceptance_rate = num_accepted / num_drawn\n",
" if i % 100 == 0 or i == maxiter - 1:\n",
" print('global_step:{:>4} loss:{: 9.3f} acceptance:{:.4f} '\n",
" 'prior_scale:{:.4f} likelihood_scale:{:.4f} '\n",
" 'fixed_effect_weights:{}'.format(\n",
" global_step_, loss_.mean(), acceptance_rate, \n",
" prior_scale_, likelihood_scale_, fixed_effect_weights_))"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 15min 9s, sys: 38min 11s, total: 53min 20s\n",
"Wall time: 1min 1s\n"
]
}
],
"source": [
"%%time\n",
"posterior_random_weights_final, kernel_results_final = tfp.mcmc.sample_chain(\n",
" num_results=int(1e3),\n",
" num_burnin_steps=int(1e3),\n",
" current_state=init_random_weights,\n",
" kernel=tfp.mcmc.HamiltonianMonteCarlo(\n",
" target_log_prob_fn=unnormalized_posterior_log_prob,\n",
" num_leapfrog_steps=2,\n",
" step_size=0.015))\n",
"\n",
"[\n",
" posterior_random_weights_final_,\n",
" kernel_results_final_,\n",
"] = sess.run([\n",
" posterior_random_weights_final,\n",
" kernel_results_final,\n",
"], feed_dict={init_random_weights: w_})"
]
},
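{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the MCEM point estimates now fixed, this longer chain (1,000 burn-in steps, 1,000 kept draws) samples the country intercepts from their posterior, warm-started from the final training-loop state `w_`."
]
},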
{
"cell_type": "code",
"execution_count": 140,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"prior_scale: 1.8611326\n",
"likelihood_scale: 0.34281462\n",
"fixed_effect_weights: [ 5.5442433 -0.68676883 0.1565091 ]\n",
"acceptance rate final: 0.9015333333333333\n"
]
}
],
"source": [
"print('prior_scale: ', prior_scale_)\n",
"print('likelihood_scale: ', likelihood_scale_)\n",
"print('fixed_effect_weights: ', fixed_effect_weights_)\n",
"print('acceptance rate final: ', kernel_results_final_.is_accepted.mean())"
]
},
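{
"cell_type": "markdown",
"metadata": {},
"source": [
"`statsmodels` is imported at the top but never used; a natural cross-check of the MCEM estimates is the same random-intercept model fit by (restricted) maximum likelihood. A minimal sketch, assuming `data` still holds the prepped columns:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative cross-check (not part of the TFP fit): frequentist\n",
"# random-intercept model with the same fixed effects, grouped by iso3.\n",
"mlm = smf.mixedlm('ln_gdppc ~ ln_TFR + ln_pop', data, groups=data['iso3'])\n",
"print(mlm.fit().summary())"
]
},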
{
"cell_type": "code",
"execution_count": 143,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MetropolisHastingsKernelResults(accepted_results=UncalibratedHamiltonianMonteCarloKernelResults(log_acceptance_correction=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_1/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, target_log_prob=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_2/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, grads_target_log_prob=[<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_3/TensorArrayGatherV3:0' shape=(15000, 195) dtype=float32>]), is_accepted=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_4/TensorArrayGatherV3:0' shape=(15000,) dtype=bool>, log_accept_ratio=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_5/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, proposed_state=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_6/TensorArrayGatherV3:0' shape=(15000, 195) dtype=float32>, proposed_results=UncalibratedHamiltonianMonteCarloKernelResults(log_acceptance_correction=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_7/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, target_log_prob=<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_8/TensorArrayGatherV3:0' shape=(15000,) dtype=float32>, grads_target_log_prob=[<tf.Tensor 'mcmc_sample_chain_3/scan/TensorArrayStack_9/TensorArrayGatherV3:0' shape=(15000, 195) dtype=float32>]), extra=[])"
]
},
"execution_count": 143,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kernel_results_final"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}