Fitting a simple random-intercept model to GDP data with TensorFlow Probability (following the linear mixed-effects example in the TFP docs), once with a fixed number of iterations and once with convergence criteria.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%%capture \n",
"\n",
"%matplotlib inline\n",
"\n",
"import IPython\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import requests\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"import warnings\n",
"\n",
"from tensorflow_probability import edward2 as ed\n",
"tfd = tfp.distributions\n",
"\n",
"from keras.constraints import non_neg\n",
"\n",
"import statsmodels.api as sm\n",
"import statsmodels.formula.api as smf\n",
"\n",
"plt.style.use('ggplot')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"## Get GDP data for testing\n",
"data = pd.read_csv('/home/j/Project/IRH/Forecasting/gdp/data/RT_2018_GDP_use.csv')\n",
"\n",
"## Prep data\n",
"data = data[['iso3', 'year', 'ln_gdppc', 'ln_TFR', 'ln_pop']]\n",
"data['intercept'] = 1.\n",
"data = data.dropna()\n",
"\n",
"# Remap categories to start from 0 and end at max(category).\n",
"data['iso3'] = data['iso3'].astype('category').cat.codes\n",
"\n",
"## Number of REs\n",
"n_res = max(data.iso3) + 1\n"
]
},
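{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch (with made-up ISO3 strings) of what `.astype('category').cat.codes` does above: each distinct string is assigned a stable integer in `0..K-1` (categories are taken in sorted order), which is exactly the indexing `tf.gather` relies on later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Hypothetical ISO3 codes, just to illustrate the remapping\n",
"demo = pd.Series(['USA', 'KEN', 'USA', 'BRA'])\n",
"demo.astype('category').cat.codes.tolist()  # [2, 1, 2, 0]: BRA=0, KEN=1, USA=2"
]
},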
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model:\n",
"$$ \\ln(GDPpc_{i,t}) = \\alpha + \\alpha_i + \\beta \\ln(TFR_{i,t}) + \\gamma \\ln(pop_{i,t}) + \\epsilon_{i,t} $$ \n",
"$$ \\epsilon_{i,t} \\sim \\mathcal{N}(0, \\sigma^2) $$ \n",
"$$ \\alpha_i \\sim \\mathcal{N}(0, \\sigma_a^2) $$\n"
]
},
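{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the generative story concrete, here is a minimal NumPy simulation of the random-intercept model (all parameter values are made up for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rng = np.random.RandomState(0)\n",
"n_countries, n_years = 5, 10\n",
"alpha, beta, gamma = 10.9, -0.9, -0.1   # made-up fixed effects\n",
"sigma, sigma_a = 0.3, 1.0               # residual and RE standard deviations\n",
"\n",
"alpha_i = rng.normal(0., sigma_a, n_countries)        # country intercepts\n",
"ln_tfr = rng.normal(1., 0.3, (n_countries, n_years))  # fake covariates\n",
"ln_pop = rng.normal(16., 2., (n_countries, n_years))\n",
"eps = rng.normal(0., sigma, (n_countries, n_years))\n",
"\n",
"ln_gdppc_sim = alpha + alpha_i[:, None] + beta * ln_tfr + gamma * ln_pop + eps\n",
"ln_gdppc_sim.shape  # (5, 10)"
]
},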
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table class=\"simpletable\">\n",
"<tr>\n",
"  <td>Model:</td> <td>MixedLM</td> <td>Dependent Variable:</td> <td>ln_gdppc</td>\n",
"</tr>\n",
"<tr>\n",
"  <td>No. Observations:</td> <td>11505</td> <td>Method:</td> <td>REML</td>\n",
"</tr>\n",
"<tr>\n",
"  <td>No. Groups:</td> <td>195</td> <td>Scale:</td> <td>0.1135</td>\n",
"</tr>\n",
"<tr>\n",
"  <td>Min. group size:</td> <td>59</td> <td>Likelihood:</td> <td>-4452.9981</td>\n",
"</tr>\n",
"<tr>\n",
"  <td>Max. group size:</td> <td>59</td> <td>Converged:</td> <td>Yes</td>\n",
"</tr>\n",
"<tr>\n",
"  <td>Mean group size:</td> <td>59.0</td> <td></td> <td></td>\n",
"</tr>\n",
"</table>\n",
"<table class=\"simpletable\">\n",
"<tr>\n",
"  <td></td> <th>Coef.</th> <th>Std.Err.</th> <th>z</th> <th>P>|z|</th> <th>[0.025</th> <th>0.975]</th>\n",
"</tr>\n",
"<tr>\n",
"  <th>Intercept</th> <td>10.938</td> <td>0.204</td> <td>53.629</td> <td>0.000</td> <td>10.538</td> <td>11.337</td>\n",
"</tr>\n",
"<tr>\n",
"  <th>ln_TFR</th> <td>-0.926</td> <td>0.015</td> <td>-63.159</td> <td>0.000</td> <td>-0.955</td> <td>-0.898</td>\n",
"</tr>\n",
"<tr>\n",
"  <th>ln_pop</th> <td>-0.094</td> <td>0.011</td> <td>-8.281</td> <td>0.000</td> <td>-0.117</td> <td>-0.072</td>\n",
"</tr>\n",
"<tr>\n",
"  <th>Group Var</th> <td>1.322</td> <td>0.403</td> <td></td> <td></td> <td></td> <td></td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"<class 'statsmodels.iolib.summary2.Summary'>\n",
"\"\"\"\n",
"         Mixed Linear Model Regression Results\n",
"========================================================\n",
"Model:            MixedLM Dependent Variable: ln_gdppc  \n",
"No. Observations: 11505   Method:             REML      \n",
"No. Groups:       195     Scale:              0.1135    \n",
"Min. group size:  59      Likelihood:         -4452.9981\n",
"Max. group size:  59      Converged:          Yes       \n",
"Mean group size:  59.0                                  \n",
"--------------------------------------------------------\n",
"          Coef.  Std.Err.    z    P>|z| [0.025 0.975]\n",
"--------------------------------------------------------\n",
"Intercept 10.938    0.204  53.629 0.000 10.538 11.337\n",
"ln_TFR    -0.926    0.015 -63.159 0.000 -0.955 -0.898\n",
"ln_pop    -0.094    0.011  -8.281 0.000 -0.117 -0.072\n",
"Group Var  1.322    0.403                            \n",
"========================================================\n",
"\n",
"\"\"\""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"## Fit in StatsModels first\n",
"md = smf.mixedlm(\"ln_gdppc ~ 1 + ln_TFR + ln_pop\", data, groups=data[\"iso3\"])\n",
"sm_fit = md.fit()\n",
"sm_fit.summary()"
]
},
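{
"cell_type": "markdown",
"metadata": {},
"source": [
"These REML estimates (intercept ≈ 10.94, `ln_TFR` ≈ -0.93, `ln_pop` ≈ -0.09, group variance ≈ 1.32) are the benchmark the TFP fits below should roughly reproduce."
]
},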
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"## Get the data for 'features' and 'labels' (xvars and yvar)\n",
"get_value = lambda dataframe, key, dtype: dataframe[key].values.astype(dtype)\n",
"\n",
"## We get the country codes as integers and then attach the fixed effects\n",
"features_train = {\n",
"    k: get_value(data, key=k, dtype=np.int32)\n",
"    for k in ['iso3']}\n",
"features_train.update({\n",
"    k: get_value(data, key=k, dtype=np.float64)\n",
"    for k in ['intercept', 'ln_TFR', 'ln_pop']})\n",
"\n",
"## Get yvar\n",
"labels_train = get_value(data, key='ln_gdppc', dtype=np.float64)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"## Graphing the model function\n",
"def strip_consts(graph_def, max_const_size=32):\n",
"    \"\"\"Strip large constant values from graph_def.\"\"\"\n",
"    strip_def = tf.GraphDef()\n",
"    for n0 in graph_def.node:\n",
"        n = strip_def.node.add()\n",
"        n.MergeFrom(n0)\n",
"        if n.op == 'Const':\n",
"            tensor = n.attr['value'].tensor\n",
"            size = len(tensor.tensor_content)\n",
"            if size > max_const_size:\n",
"                tensor.tensor_content = bytes(\"<stripped %d bytes>\" % size, 'utf-8')\n",
"    return strip_def\n",
"\n",
"def draw_graph(model, *args, **kwargs):\n",
"    \"\"\"Visualize TensorFlow graph.\"\"\"\n",
"    graph = tf.Graph()\n",
"    with graph.as_default():\n",
"        model(*args, **kwargs)\n",
"    graph_def = graph.as_graph_def()\n",
"    strip_def = strip_consts(graph_def, max_const_size=32)\n",
"    code = \"\"\"\n",
"        <script>\n",
"          function load() {{\n",
"            document.getElementById(\"{id}\").pbtxt = {data};\n",
"          }}\n",
"        </script>\n",
"        <link rel=\"import\" href=\"https://tensorboard.appspot.com/tf-graph-basic.build.html\" onload=load()>\n",
"        <div style=\"height:600px\">\n",
"          <tf-graph-basic id=\"{id}\"></tf-graph-basic>\n",
"        </div>\n",
"    \"\"\".format(data=repr(str(strip_def)), id='graph' + str(np.random.rand()))\n",
"\n",
"    iframe = \"\"\"\n",
"        <iframe seamless style=\"width:1200px;height:620px;border:0\" srcdoc=\"{}\"></iframe>\n",
"    \"\"\".format(code.replace('\"', '&quot;'))\n",
"    IPython.display.display(IPython.display.HTML(iframe))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run (1):\n",
"### First, fit the model for a fixed number of iterations, with no convergence criterion"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"## Define our mixed effects model, returning the prediction\n",
"def linear_mixed_effects_model(features):\n",
"\n",
"    # Set up fixed effects and other parameters.\n",
"    intercept = tf.get_variable(\"intercept\", [])  # alpha in eq\n",
"\n",
"    ## A sign constraint is not needed, but could be imposed like so:\n",
"    # effect_ln_TFR = tf.get_variable(\"effect_ln_TFR\", [],\n",
"    #     constraint=lambda x: tf.clip_by_value(x, -np.infty, 0))  # beta in eq\n",
"\n",
"    ## Here we instead give the TFR effect a Normal prior with a learned scale\n",
"    stddev_raw_tfr = tf.exp(tf.get_variable(\"stddev_raw_tfr\", []))\n",
"    effect_ln_TFR = ed.Normal(loc=tf.zeros(1),\n",
"                              scale=stddev_raw_tfr,\n",
"                              name='effect_ln_TFR')  # beta in eq\n",
"\n",
"    # gamma in eq\n",
"    effect_ln_pop = tf.get_variable(\"effect_ln_pop\", [])\n",
"\n",
"    ## The two standard deviations (exp to force positive estimates)\n",
"    stddev_iso3 = tf.exp(tf.get_variable(\"stddev_iso3\", []))\n",
"    model_stddev = tf.exp(tf.get_variable(\"model_stddev\", []))\n",
"\n",
"    # Set up random effects.\n",
"    effect_iso3 = ed.MultivariateNormalDiag(\n",
"        loc=tf.zeros(n_res),\n",
"        scale_identity_multiplier=stddev_iso3,\n",
"        name=\"effect_iso3\")\n",
"\n",
"    # Set up likelihood given fixed and random effects.\n",
"    # Note we use `tf.gather` instead of matrix-multiplying a design matrix of\n",
"    # one-hot vectors. The latter is memory-intensive if there are many groups.\n",
"    ln_gdppc = ed.Normal(\n",
"        loc=((intercept * features['intercept']) +\n",
"             (effect_ln_pop * features[\"ln_pop\"]) +\n",
"             (effect_ln_TFR * features[\"ln_TFR\"]) +\n",
"             tf.gather(effect_iso3, features[\"iso3\"])),\n",
"        scale=model_stddev, name=\"ln_gdppc\")\n",
"    return ln_gdppc\n",
"\n",
"# Unnormalized target density as a function of the states of the REs.\n",
"def target_log_prob_fn(effect_iso3, effect_ln_TFR):\n",
"    return log_joint(  # fix `features` and `outcome` to the training data\n",
"        features=features_train,\n",
"        effect_iso3=effect_iso3,\n",
"        effect_ln_TFR=effect_ln_TFR,\n",
"        ln_gdppc=labels_train)"
]
},
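{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the `tf.gather` trick noted above: indexing the RE vector by integer group codes gives the same result as multiplying a one-hot design matrix by that vector, without ever materializing the `n_obs x n_groups` matrix."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"effects = np.array([0.5, -1.2, 0.7])   # made-up per-group effects\n",
"codes = np.array([0, 2, 2, 1, 0])      # group code for each observation\n",
"\n",
"gathered = effects[codes]              # the tf.gather analogue\n",
"one_hot = np.eye(len(effects))[codes]  # n_obs x n_groups design matrix\n",
"np.allclose(gathered, one_hot @ effects)  # True"
]
},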
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set up the graph and the machinery of the EM algorithm (MCMC E-step, gradient-based M-step)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Wrap model in a template. All calls to the model template will use the same\n",
"# TensorFlow variables.\n",
"tf.reset_default_graph()\n",
"model_template = tf.make_template(\"model\", linear_mixed_effects_model)\n",
"\n",
"## Make joint LL\n",
"log_joint = ed.make_log_joint_fn(model_template)\n",
"\n",
"# draw_graph(linear_mixed_effects_model, features_train)  ## draws the computational graph\n",
"\n",
"## Set up E-step (MCMC) for RE vars\n",
"effect_iso3 = tf.get_variable(\"effect_iso3\", [n_res], trainable=False)\n",
"effect_ln_TFR = tf.get_variable(\"effect_ln_TFR\", [1], trainable=False)\n",
"\n",
"## Global step variable holder\n",
"global_step = tf.Variable(0, trainable=False, name=\"global_step\")\n",
"\n",
"## Set up MCMC object\n",
"hmc = tfp.mcmc.HamiltonianMonteCarlo(\n",
"    target_log_prob_fn=target_log_prob_fn,\n",
"    step_size=0.015,\n",
"    num_leapfrog_steps=3)\n",
"\n",
"## RE state to update\n",
"current_state = [effect_iso3, effect_ln_TFR]\n",
"\n",
"## Update state\n",
"with warnings.catch_warnings():\n",
"    warnings.simplefilter(\"ignore\")\n",
"    next_state, kernel_results = hmc.one_step(\n",
"        current_state=current_state,\n",
"        previous_kernel_results=hmc.bootstrap_results(current_state))\n",
"\n",
"## Update Expectation\n",
"expectation_update = tf.group(effect_iso3.assign(next_state[0]),\n",
"                              effect_ln_TFR.assign(next_state[1]))\n",
"\n",
"# Set up M-step (gradient descent) using the Adam optimizer,\n",
"# with an exponentially decaying learning rate (not needed, but why not)\n",
"with tf.control_dependencies([expectation_update]):\n",
"    loss = -target_log_prob_fn(effect_iso3, effect_ln_TFR)\n",
"    learning_rate = tf.train.exponential_decay(learning_rate=0.01,\n",
"                                               global_step=global_step,\n",
"                                               decay_steps=250,\n",
"                                               decay_rate=0.90, staircase=True)\n",
"    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
"    # Pass global_step so the decay schedule actually advances\n",
"    minimization_update = optimizer.minimize(loss, global_step=global_step)"
]
},
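{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `staircase=True` and `global_step` passed to `minimize` (as above), the schedule follows the usual `tf.train.exponential_decay` rule, dropping the learning rate by 10% every 250 M-steps:\n",
"$$ \\eta(s) = 0.01 \\times 0.9^{\\lfloor s / 250 \\rfloor} $$"
]
},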
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fit the first run"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warm-Up Iteration:   0 Acceptance Rate: 0.000\n",
"Warm-Up Iteration: 250 Acceptance Rate: 0.000\n",
"Warm-Up Iteration: 500 Acceptance Rate: 0.000\n",
"Warm-Up Iteration: 749 Acceptance Rate: 0.000\n",
"Iteration: 0 Acceptance Rate: 0.000 Loss: 143640016.000,ln_TFR:-0.559, ln_pop:-1.441, intercept:-1.694\n",
"Iteration: 0 Acceptance Rate: 0.000 Loss: 139404736.000,ln_TFR:-0.559, ln_pop:-1.431, intercept:-1.684\n",
"Iteration: 0 Acceptance Rate: 0.000 Loss: 135291040.000,ln_TFR:-0.559, ln_pop:-1.421, intercept:-1.674\n",
"Iteration: 0 Acceptance Rate: 0.000 Loss: 131297704.000,ln_TFR:-0.559, ln_pop:-1.411, intercept:-1.664\n",
"Iteration: 0 Acceptance Rate: 0.000 Loss: 127423352.000,ln_TFR:-0.559, ln_pop:-1.401, intercept:-1.654\n"
]
}
],
"source": [
"%%time\n",
"\n",
"num_warmup_iters = 750\n",
"num_iters = 3000\n",
"num_accepted = 0\n",
"\n",
"## Record samples of the fitted params and the loss\n",
"effect_iso3_samples = np.zeros([num_iters, n_res])\n",
"effect_ln_TFR_samples = np.zeros([num_iters])\n",
"effect_ln_pop_samples = np.zeros([num_iters])\n",
"effect_int_samples = np.zeros([num_iters])\n",
"loss_history = np.zeros([num_iters])\n",
"\n",
"## Push fixed effects into scope so that we can record them\n",
"## (by default, only the REs get recorded since they're in the state)\n",
"with tf.variable_scope('model', reuse=True):\n",
"    # ln_tfr_wghts = tf.get_variable(name='effect_ln_TFR', dtype=np.float32)\n",
"    ln_pop_wghts = tf.get_variable(name='effect_ln_pop', dtype=np.float32)\n",
"    int_wghts = tf.get_variable(name='intercept', dtype=np.float32)\n",
"\n",
"## Start our sesh\n",
"with tf.Session() as sess:\n",
"    sess.run(tf.global_variables_initializer())\n",
"\n",
"    # Run warm-up stage.\n",
"    for t in range(num_warmup_iters):\n",
"        _, is_accepted_val = sess.run(\n",
"            [expectation_update, kernel_results.is_accepted])\n",
"        num_accepted += is_accepted_val\n",
"        if t % 250 == 0 or t == num_warmup_iters - 1:\n",
"            print(\"Warm-Up Iteration: {:>3} Acceptance Rate: {:.3f}\".format(\n",
"                t, num_accepted / (t + 1)))\n",
"\n",
"    num_accepted = 0  # reset acceptance rate counter\n",
"\n",
"    # Run iterations (no convergence criteria)\n",
"    for t in range(num_iters):\n",
"        for _ in range(5):  # run 5 MCMC iterations before every joint EM update\n",
"            _ = sess.run(expectation_update)\n",
"        [\n",
"            _,\n",
"            _,\n",
"            effect_iso3_val,\n",
"            ln_tfr_val,\n",
"            ln_pop_val,\n",
"            int_val,\n",
"            is_accepted_val,\n",
"            loss_val,\n",
"        ] = sess.run([\n",
"            expectation_update,\n",
"            minimization_update,\n",
"            effect_iso3,\n",
"            # ln_tfr_wghts,\n",
"            effect_ln_TFR[0],\n",
"            ln_pop_wghts,\n",
"            int_wghts,\n",
"            kernel_results.is_accepted,\n",
"            loss,\n",
"        ])\n",
"        effect_iso3_samples[t, :] = effect_iso3_val\n",
"        effect_ln_TFR_samples[t] = ln_tfr_val\n",
"        effect_ln_pop_samples[t] = ln_pop_val\n",
"        effect_int_samples[t] = int_val\n",
"        num_accepted += is_accepted_val\n",
"        loss_history[t] = loss_val\n",
"        if t % 1000 == 0:\n",
"            print(\"Iteration: {} Acceptance Rate: {:.3f} Loss: {:.3f},ln_TFR:{:.3f}, ln_pop:{:.3f}, intercept:{:.3f}\".format(\n",
"                t, num_accepted / (t + 1),\n",
"                loss_val, ln_tfr_val, ln_pop_val, int_val))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run (2):\n",
"### Fit using parameter- and loss-convergence criteria"
]
},
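{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concretely, with tolerance $\\tau = 10^{-7}$ on both criteria, the loop below stops as soon as either\n",
"$$ \\lVert \\theta_t - \\theta_{t-1} \\rVert_2 < \\tau \\quad \\text{or} \\quad |\\mathcal{L}_t - \\mathcal{L}_{t-1}| < \\tau, $$\n",
"where $\\theta$ stacks the country intercepts and the three fixed effects, and $\\mathcal{L}$ is the negative joint log-density used as the loss."
]
},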
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def linear_mixed_effects_model(features):\n",
"\n",
"    # Set up fixed effects and other parameters.\n",
"    intercept = tf.get_variable(\"intercept\", [])  # alpha in eq\n",
"    effect_ln_TFR = tf.get_variable(\"effect_ln_TFR\", [],\n",
"        constraint=lambda x: tf.clip_by_value(x, -np.infty, 0))  # beta in eq\n",
"    effect_ln_pop = tf.get_variable(\"effect_ln_pop\", [])  # gamma in eq\n",
"    stddev_iso3 = tf.exp(tf.get_variable(\"stddev_raw_iso3\", []))\n",
"    model_stddev = tf.exp(tf.get_variable(\"model_stddev\", []))\n",
"\n",
"    # Set up random effects.\n",
"    effect_iso3 = ed.MultivariateNormalDiag(\n",
"        loc=tf.zeros(n_res),\n",
"        scale_identity_multiplier=stddev_iso3,\n",
"        name=\"effect_iso3\")\n",
"\n",
"    # The following would let us place priors on the fixed effects as well\n",
"    # (left commented out):\n",
"    # stddev_pop = tf.exp(tf.get_variable(\"stddev_pop\", []))\n",
"    # effect_ln_pop = ed.Normal(loc=tf.get_variable(\"effect_ln_pop\", [1]), scale=stddev_pop,\n",
"    #                           name=\"effect_ln_pop\")\n",
"\n",
"    # Set up likelihood given fixed and random effects.\n",
"    ln_gdppc = ed.Normal(\n",
"        loc=((intercept * features['intercept']) +\n",
"             (effect_ln_TFR * features[\"ln_TFR\"]) +\n",
"             (effect_ln_pop * features[\"ln_pop\"]) +\n",
"             tf.gather(effect_iso3, features[\"iso3\"])),\n",
"        scale=model_stddev, name=\"ln_gdppc\")\n",
"    return ln_gdppc\n",
"\n",
"## Only the country REs are in the MCMC state this time\n",
"def target_log_prob_fn(effect_iso3):\n",
"    return log_joint(\n",
"        features=features_train,\n",
"        effect_iso3=effect_iso3,\n",
"        # effect_ln_TFR=effect_ln_TFR,\n",
"        # effect_ln_pop=effect_ln_pop,\n",
"        # intercept=intercept,\n",
"        ln_gdppc=labels_train)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Wrap model in a template. All calls to the model template will use the same\n",
"# TensorFlow variables.\n",
"tf.reset_default_graph()\n",
"model_template = tf.make_template(\"model\", linear_mixed_effects_model)\n",
"\n",
"## Make joint LL\n",
"log_joint = ed.make_log_joint_fn(model_template)\n",
"\n",
"# draw_graph(linear_mixed_effects_model, features_train)\n",
"\n",
"## Set up E-step (MCMC) for RE vars\n",
"effect_iso3 = tf.get_variable(\"effect_iso3\", [n_res], trainable=False)\n",
"# effect_ln_TFR = tf.get_variable(\"effect_ln_TFR\", [1], trainable=False)\n",
"# effect_ln_pop = tf.get_variable(\"effect_ln_pop\", [1], trainable=False)\n",
"\n",
"## Global step\n",
"global_step = tf.Variable(0, trainable=False, name=\"global_step\")\n",
"\n",
"## RE state to update\n",
"current_state = [effect_iso3]\n",
"\n",
"## Set up MCMC object\n",
"hmc = tfp.mcmc.HamiltonianMonteCarlo(\n",
"    target_log_prob_fn=target_log_prob_fn,\n",
"    step_size=0.02,\n",
"    num_leapfrog_steps=2)\n",
"\n",
"## Update state\n",
"with warnings.catch_warnings():\n",
"    warnings.simplefilter(\"ignore\")\n",
"    next_state, kernel_results = hmc.one_step(\n",
"        current_state=current_state,\n",
"        previous_kernel_results=hmc.bootstrap_results(current_state))\n",
"\n",
"## Update Expectation\n",
"expectation_update = tf.group(effect_iso3.assign(next_state[0]))\n",
"\n",
"# Set up M-step (gradient descent).\n",
"with tf.control_dependencies([expectation_update]):\n",
"    loss = -target_log_prob_fn(effect_iso3)\n",
"    learning_rate = tf.train.exponential_decay(learning_rate=0.01,\n",
"                                               global_step=global_step,\n",
"                                               decay_steps=20,\n",
"                                               decay_rate=0.90, staircase=True)\n",
"    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
"    # Pass global_step so the decay schedule actually advances\n",
"    minimization_update = optimizer.minimize(loss, global_step=global_step)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fit"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warm-Up Iteration:   0 \n",
"Warm-Up Iteration: 500 \n",
"Warm-Up Iteration: 999 \n",
"Run iterations with loss convergence criteria\n",
"Iteration: 0 Loss: 23787.656, ln_TFR:-1.380, ln_pop-0.874 intercept:0.540\n",
"Iteration: 1000 Loss: 4630.398, ln_TFR:-0.913, ln_pop-0.085 intercept:-5.559\n",
"Iteration: 2000 Loss: 4603.788, ln_TFR:-0.915, ln_pop-0.090 intercept:-5.253\n",
"Iteration: 3000 Loss: 4619.049, ln_TFR:-0.903, ln_pop-0.070 intercept:-4.676\n",
"CPU times: user 2min 21s, sys: 3.9 s, total: 2min 25s\n",
"Wall time: 1min 26s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"## Set up the convergence tolerances\n",
"TOL_PARAM, TOL_LOSS = 1e-7, 1e-7\n",
"num_warmup_iters = 1000\n",
"MAX_ITER = 3500\n",
"num_accepted = 0\n",
"\n",
"## Record samples\n",
"effect_iso3_samples = np.zeros([MAX_ITER, n_res])\n",
"effect_ln_TFR_samples = np.zeros([MAX_ITER])\n",
"effect_ln_pop_samples = np.zeros([MAX_ITER])\n",
"effect_int_samples = np.zeros([MAX_ITER])\n",
"loss_history = np.zeros([MAX_ITER])\n",
"\n",
"# Push fixed effects into scope so that we can record them\n",
"with tf.variable_scope('model', reuse=True):\n",
"    ln_tfr_wghts = tf.get_variable(name='effect_ln_TFR', dtype=np.float32)\n",
"    ln_pop_wghts = tf.get_variable(name='effect_ln_pop', dtype=np.float32)\n",
"    int_wghts = tf.get_variable(name='intercept', dtype=np.float32)\n",
"\n",
"sess = tf.Session()\n",
"sess.run(tf.global_variables_initializer())\n",
"\n",
"# Run warm-up stage.\n",
"for t in range(num_warmup_iters):\n",
"    _, is_accepted_val, ln_tfr_val, ln_pop_val, int_val = sess.run(\n",
"        [expectation_update, kernel_results.is_accepted, ln_tfr_wghts, ln_pop_wghts, int_wghts])\n",
"    num_accepted += is_accepted_val\n",
"    if t % 500 == 0 or t == num_warmup_iters - 1:\n",
"        print(\"Warm-Up Iteration: {:>3} \".format(t))\n",
"\n",
"# reset acceptance rate counter\n",
"num_accepted = 0\n",
"\n",
"print(\"Run iterations with loss convergence criteria\")\n",
"saver = tf.train.Saver()\n",
"\n",
"for t in range(MAX_ITER):\n",
"    for _ in range(5):  # run 5 MCMC iterations before every joint EM update\n",
"        _ = sess.run(expectation_update)\n",
"    [\n",
"        _,\n",
"        _,\n",
"        effect_iso3_val,\n",
"        ln_tfr_val, ln_pop_val, int_val,\n",
"        is_accepted_val,\n",
"        loss_val,\n",
"    ] = sess.run([\n",
"        expectation_update,\n",
"        minimization_update,\n",
"        effect_iso3,\n",
"        ln_tfr_wghts, ln_pop_wghts, int_wghts,\n",
"        kernel_results.is_accepted,\n",
"        loss,\n",
"    ])\n",
"\n",
"    effect_iso3_samples[t, :] = effect_iso3_val\n",
"    effect_ln_TFR_samples[t] = ln_tfr_val\n",
"    effect_ln_pop_samples[t] = ln_pop_val\n",
"    effect_int_samples[t] = int_val\n",
"\n",
"    num_accepted += is_accepted_val\n",
"    loss_history[t] = loss_val\n",
"    if t % 1000 == 0:\n",
"        print(\"Iteration: {} Loss: {:.3f}, ln_TFR:{:.3f}, ln_pop:{:.3f}, intercept:{:.3f}\".format(\n",
"            t, loss_val, ln_tfr_val, ln_pop_val, int_val))\n",
"\n",
"    ## Get difference of params and losses\n",
"    if t > 1:\n",
"        ## Stack the REs and fixed effects into one flat parameter vector\n",
"        params_t = np.concatenate([effect_iso3_samples[t, :],\n",
"                                   [effect_ln_TFR_samples[t],\n",
"                                    effect_ln_pop_samples[t],\n",
"                                    effect_int_samples[t]]])\n",
"        params_tm1 = np.concatenate([effect_iso3_samples[t - 1, :],\n",
"                                     [effect_ln_TFR_samples[t - 1],\n",
"                                      effect_ln_pop_samples[t - 1],\n",
"                                      effect_int_samples[t - 1]]])\n",
"        diff_norm = np.linalg.norm(params_t - params_tm1)\n",
"\n",
"        loss_diff = np.abs(loss_val - loss_history[t - 1])\n",
"\n",
"        if diff_norm < TOL_PARAM:\n",
"            print('Parameter convergence in {} iterations!'.format(t))\n",
"            break\n",
"\n",
"        if loss_diff < TOL_LOSS:\n",
"            print('Loss function convergence in {} iterations!'.format(t))\n",
"            break\n",
"\n",
"    if t == MAX_ITER - 1:\n",
"        print('Max number of iterations reached without convergence.')\n"
]
},
}, | |
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## These are supposed to be helpful for getting posterior intervals... TBD\n",
"\n",
"def interceptor(rv_constructor, *rv_args, **rv_kwargs):\n",
"    \"\"\"Replaces prior on effects with empirical posterior mean from MCMC.\"\"\"\n",
"    name = rv_kwargs.pop(\"name\")\n",
"    if name == \"effect_ln_TFR\":\n",
"        rv_kwargs[\"value\"] = effect_ln_TFR\n",
"    elif name == \"effect_ln_pop\":\n",
"        rv_kwargs[\"value\"] = effect_ln_pop\n",
"    elif name == \"intercept\":\n",
"        rv_kwargs[\"value\"] = intercept\n",
"    elif name == \"effect_iso3\":\n",
"        rv_kwargs[\"value\"] = effect_iso3\n",
"    return rv_constructor(*rv_args, **rv_kwargs)\n",
"\n",
"with ed.interception(interceptor):\n",
"    ratings_posterior = model_template(features=features_train)\n",
"    ratings_prediction = ratings_posterior.distribution.sample(100)\n",
"\n",
"with sess.as_default():\n",
"    output_test = ratings_prediction.eval()\n"
]
},
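{
"cell_type": "markdown",
"metadata": {},
"source": [
"One plausible way to turn the draws into intervals (a sketch, assuming `output_test` above holds 100 posterior-predictive draws of shape `(100, n_obs)`): summarize them pointwise with `np.percentile`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lo, med, hi = np.percentile(output_test, [2.5, 50., 97.5], axis=0)\n",
"med[:5], (hi - lo)[:5]  # posterior-predictive medians and 95% interval widths, first 5 obs"
]
},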
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output_test"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}