Skip to content

Instantly share code, notes, and snippets.

@sadatnfs
Last active April 17, 2018 22:02
Show Gist options
  • Save sadatnfs/c60d1d6eab664620f38c2cf44a0dfc24 to your computer and use it in GitHub Desktop.
Save sadatnfs/c60d1d6eab664620f38c2cf44a0dfc24 to your computer and use it in GitHub Desktop.
Fit a multidimensional GPR on a coregionalization based kernel (using health financing data)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Purpose: Fit a multi-output (correlated) Gaussian process regression (GPR) using a coregionalization kernel\n",
"### Created: April 2017"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
}
],
"source": [
"import gpflow as gpflow\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"plt.style.use('ggplot')\n",
"%matplotlib inline\n",
"import seaborn as sns\n",
"np.random.seed(1)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>location_id</th>\n",
" <th>year_id</th>\n",
" <th>iso3</th>\n",
" <th>ghespc</th>\n",
" <th>cv_ln_ggepc</th>\n",
" <th>variance_ghespc</th>\n",
" <th>ln_ldipc</th>\n",
" <th>ppppc</th>\n",
" <th>variance_ppppc</th>\n",
" <th>ooppc</th>\n",
" <th>...</th>\n",
" <th>log_ppppc_st</th>\n",
" <th>ppppc_st</th>\n",
" <th>log_ooppc_st</th>\n",
" <th>ooppc_st</th>\n",
" <th>log_ghespc</th>\n",
" <th>log_ppppc</th>\n",
" <th>log_ooppc</th>\n",
" <th>log_ghespc_stage1</th>\n",
" <th>log_ppppc_stage1</th>\n",
" <th>log_ooppc_stage1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>1990</td>\n",
" <td>CHN</td>\n",
" <td>NaN</td>\n",
" <td>5.812140</td>\n",
" <td>NaN</td>\n",
" <td>7.448037</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>2.651663</td>\n",
" <td>14.177595</td>\n",
" <td>3.892276</td>\n",
" <td>49.022319</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>3.440668</td>\n",
" <td>2.137880</td>\n",
" <td>4.162195</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>6</td>\n",
" <td>1991</td>\n",
" <td>CHN</td>\n",
" <td>NaN</td>\n",
" <td>5.875323</td>\n",
" <td>NaN</td>\n",
" <td>7.496622</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>2.710601</td>\n",
" <td>15.038309</td>\n",
" <td>3.924484</td>\n",
" <td>50.626944</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>3.506140</td>\n",
" <td>2.212852</td>\n",
" <td>4.211614</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6</td>\n",
" <td>1992</td>\n",
" <td>CHN</td>\n",
" <td>NaN</td>\n",
" <td>5.975627</td>\n",
" <td>NaN</td>\n",
" <td>7.553911</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>2.754407</td>\n",
" <td>15.711716</td>\n",
" <td>3.954717</td>\n",
" <td>52.180940</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>3.598402</td>\n",
" <td>2.272888</td>\n",
" <td>4.260822</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>6</td>\n",
" <td>1993</td>\n",
" <td>CHN</td>\n",
" <td>NaN</td>\n",
" <td>6.077180</td>\n",
" <td>NaN</td>\n",
" <td>7.619321</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>2.812618</td>\n",
" <td>16.653468</td>\n",
" <td>3.993309</td>\n",
" <td>54.234040</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>3.697574</td>\n",
" <td>2.349654</td>\n",
" <td>4.315554</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>6</td>\n",
" <td>1994</td>\n",
" <td>CHN</td>\n",
" <td>NaN</td>\n",
" <td>6.178471</td>\n",
" <td>NaN</td>\n",
" <td>7.690529</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>...</td>\n",
" <td>2.882325</td>\n",
" <td>17.855745</td>\n",
" <td>4.038214</td>\n",
" <td>56.724969</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>3.801199</td>\n",
" <td>2.438130</td>\n",
" <td>4.372979</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 31 columns</p>\n",
"</div>"
],
"text/plain": [
" location_id year_id iso3 ghespc cv_ln_ggepc variance_ghespc ln_ldipc \\\n",
"0 6 1990 CHN NaN 5.812140 NaN 7.448037 \n",
"1 6 1991 CHN NaN 5.875323 NaN 7.496622 \n",
"2 6 1992 CHN NaN 5.975627 NaN 7.553911 \n",
"3 6 1993 CHN NaN 6.077180 NaN 7.619321 \n",
"4 6 1994 CHN NaN 6.178471 NaN 7.690529 \n",
"\n",
" ppppc variance_ppppc ooppc ... log_ppppc_st ppppc_st \\\n",
"0 NaN NaN NaN ... 2.651663 14.177595 \n",
"1 NaN NaN NaN ... 2.710601 15.038309 \n",
"2 NaN NaN NaN ... 2.754407 15.711716 \n",
"3 NaN NaN NaN ... 2.812618 16.653468 \n",
"4 NaN NaN NaN ... 2.882325 17.855745 \n",
"\n",
" log_ooppc_st ooppc_st log_ghespc log_ppppc log_ooppc log_ghespc_stage1 \\\n",
"0 3.892276 49.022319 NaN NaN NaN 3.440668 \n",
"1 3.924484 50.626944 NaN NaN NaN 3.506140 \n",
"2 3.954717 52.180940 NaN NaN NaN 3.598402 \n",
"3 3.993309 54.234040 NaN NaN NaN 3.697574 \n",
"4 4.038214 56.724969 NaN NaN NaN 3.801199 \n",
"\n",
" log_ppppc_stage1 log_ooppc_stage1 \n",
"0 2.137880 4.162195 \n",
"1 2.212852 4.211614 \n",
"2 2.272888 4.260822 \n",
"3 2.349654 4.315554 \n",
"4 2.438130 4.372979 \n",
"\n",
"[5 rows x 31 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"## Get the full set of data with a first stage linear fit\n",
"data_raw = pd.read_csv('/home/j/temp/sadatnfs/test_gpr_HEs.csv')\n",
"data_raw.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"## Prep our data: keep a single country, indexed by location_id / iso3 / year_id\n",
"iso = 'JPN'\n",
"data_raw['year'] = data_raw['year_id']  # plain 'year' column stays available after year_id moves into the index\n",
"country_rows = data_raw.query(\"iso3 == '{}'\".format(iso))\n",
"data_use = country_rows.set_index(['location_id', 'iso3', 'year_id'])\n",
"\n",
"## Flag every outcome that has no observed value at all (all() checks if all are True)\n",
"ghes_na = data_use['log_ghespc'].isnull().all()\n",
"ppp_na = data_use['log_ppppc'].isnull().all()\n",
"oop_na = data_use['log_ooppc'].isnull().all()\n",
"\n",
"## A fully-missing outcome is replaced by its *_stage1 column for now\n",
"for all_missing, outcome in ((ghes_na, 'log_ghespc'), (ppp_na, 'log_ppppc'), (oop_na, 'log_ooppc')):\n",
"    if all_missing:\n",
"        data_use[outcome] = data_use[outcome + '_stage1']"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"## Simple plotting function\n",
"def plot_gp(x, mu, var, **kwargs):\n",
"    \"\"\"Plot `mu` against `x` and shade the band mu +/- 2*sqrt(var) around it.\n",
"\n",
"    Extra keyword arguments are forwarded to both `plt.plot` and\n",
"    `plt.fill_between`. Draws through the pyplot interface and returns None.\n",
"    \"\"\"\n",
"    plt.plot(x, mu, lw=2, **kwargs)\n",
"    half_width = 2*np.sqrt(var)\n",
"    lower, upper = mu - half_width, mu + half_width\n",
"    plt.fill_between(x, lower, upper, alpha=0.2, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"## Outcome (prediction) variables: one GP output per entry\n",
"yvars = [\n",
"    'log_ghespc',\n",
"    'log_ppppc',\n",
"    'log_ooppc',\n",
"]\n",
"\n",
"## Input columns; 'year' is listed first, so it is column 0 of every X matrix built from xvars\n",
"xvars = [\n",
"    'year',\n",
"    'cv_ln_ggepc',\n",
"    'ln_ldipc',\n",
"    'ln_tfr',\n",
"    'logit_haq',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"## Prep the training dataset: one (X, Y) pair per outcome.\n",
"## X and Y share the SAME row mask (non-missing log outcome). Masking X on the raw\n",
"## outcome column instead lets X and Y disagree on row count whenever the log outcome\n",
"## was back-filled from its *_stage1 column while the raw outcome is still all-NaN.\n",
"obs1 = data_use[yvars[0]].notnull()\n",
"obs2 = data_use[yvars[1]].notnull()\n",
"obs3 = data_use[yvars[2]].notnull()\n",
"\n",
"Y1 = data_use.loc[obs1, yvars[0]].values\n",
"Y2 = data_use.loc[obs2, yvars[1]].values\n",
"Y3 = data_use.loc[obs3, yvars[2]].values\n",
"\n",
"X1 = data_use.loc[obs1, xvars].values.astype(float)\n",
"X2 = data_use.loc[obs2, xvars].values.astype(float)\n",
"X3 = data_use.loc[obs3, xvars].values.astype(float)\n",
"\n",
"## Inputs for every row of the selected country, used later for prediction\n",
"X_full = data_use[xvars].values.astype(float)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Augment the data with 0/1/2s to indicate different outputs.\n",
"# Exactly ONE index column is appended to each X block, landing in column 5, which is\n",
"# the column the Coregion kernel is configured to read (active_dims=[5]).\n",
"# Appending np.zeros_like(X) would add one index column per input column instead.\n",
"\n",
"X_augmented = np.vstack([\n",
"    np.hstack((X_block, np.full((X_block.shape[0], 1), float(output_idx))))\n",
"    for output_idx, X_block in enumerate((X1, X2, X3))\n",
"])\n",
"\n",
"# Y gets the same output index as its second column\n",
"Y_augmented = np.vstack([\n",
"    np.column_stack((Y_block, np.full_like(Y_block, output_idx)))\n",
"    for output_idx, Y_block in enumerate((Y1, Y2, Y3))\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"### Base kernel: Matern3/2 across time (column 0) plus Linear across the other inputs (columns 1-4)\n",
"k1 = gpflow.kernels.Matern32(input_dim=1, active_dims=[0])\n",
"k2 = gpflow.kernels.Linear(input_dim=4, active_dims=[1, 2, 3, 4])\n",
"# Alternative split, kept for reference:\n",
"# k1 = gpflow.kernels.Matern32(input_dim = 2, active_dims=[0,1])\n",
"# k2 = gpflow.kernels.Linear(input_dim = 3, active_dims=[2,3,4])\n",
"base_kern = k1 + k2\n",
"\n",
"### Coregionalization kernel:\n",
"# The 'coregion' kernel indexes the outputs, and acts on column 5 (the output-index column)\n",
"coreg = gpflow.kernels.Coregion(1, output_dim=3, rank=3, active_dims=[5])\n",
"\n",
"## Our main kernel is the product of the base kernel and the coregionalization kernel\n",
"kern = base_kern * coreg\n",
"\n",
"## Define our likelihood function: a combination of 3 Gaussian likelihoods, one per output\n",
"lik = gpflow.likelihoods.SwitchedLikelihood(\n",
"    [gpflow.likelihoods.Gaussian() for _ in range(3)]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"## A quick clean up assurance: drop any model left over from an earlier run of this cell.\n",
"## `del` on an unbound name can only raise NameError, so catch exactly that and nothing else.\n",
"try:\n",
"    del m\n",
"except NameError:\n",
"    pass\n",
"\n",
"## Build the Variational GP model\n",
"m = gpflow.models.VGP(X_augmented, Y_augmented, kern=kern, likelihood=lik)\n",
"\n",
"if hasattr(m.kern, 'coregion'):\n",
"    ## Re-initialize kernel's coregion matrix with random entries to avoid saddle-point and re-optimize\n",
"    ## (3x3 random normal draws, with the diagonal then set to 1.0)\n",
"    rand_mat = np.random.randn(3, 3)\n",
"    np.fill_diagonal(rand_mat, 1.0)\n",
"    m.kern.coregion.W = rand_mat\n",
"\n",
"\n",
"## Do we want to fix some parameters?\n",
"## The Matern lengthscale is pinned at 3 and excluded from training\n",
"m.kern.sum.matern32.lengthscales = 3\n",
"m.kern.sum.matern32.lengthscales.trainable = False\n",
"\n",
"## The coregion kappa vector is pinned at ones and excluded from training\n",
"m.kern.coregion.kappa = np.array([1.,1.,1.])\n",
"m.kern.coregion.kappa.trainable = False\n",
"\n",
"# m.kern.coregion.W.trainable = False\n",
"# m.kern.matern32.variance.trainable = True\n",
"\n",
"## Compile and print our initial model setup\n",
"m.compile()\n",
"print(\"\\n Initializing our model: \\n\")\n",
"m.as_pandas_table()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Train the VGP model (at most 5000 optimizer iterations) and print the final parameters\n",
"optimizer = gpflow.train.ScipyOptimizer()\n",
"optimizer.minimize(m, maxiter=5000)\n",
"m.as_pandas_table()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Predict for each of the 3 outputs over every row of X_full.\n",
"## A single output-index column (0, 1 or 2) is appended as column 5, the column the\n",
"## Coregion kernel reads; np.zeros_like(X_full) would append one per input column.\n",
"n_full = X_full.shape[0]\n",
"mu1, var1 = m.predict_y(np.hstack((X_full, np.zeros((n_full, 1)))))\n",
"mu2, var2 = m.predict_y(np.hstack((X_full, np.ones((n_full, 1)))))\n",
"mu3, var3 = m.predict_y(np.hstack((X_full, np.full((n_full, 1), 2.0))))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Plot stuff: one panel per outcome, observed points as crosses plus the\n",
"# predicted mean and its +/- 2 sd band from plot_gp\n",
"\n",
"colorfill = sns.color_palette('Greens')\n",
"plt.rcParams[\"figure.figsize\"] = (12, 6)\n",
"\n",
"# NOTE(review): column 0 of each mean/variance array is plotted for all three outcomes.\n",
"# If predict_y returns one column per likelihood under SwitchedLikelihood, the variance\n",
"# column should match the outcome index (0, 1, 2) -- confirm against the GPflow version in use.\n",
"panels = [\n",
"    (\"Log GHESpc\", X1, Y1, mu1, var1),\n",
"    (\"Log PPPpc\", X2, Y2, mu2, var2),\n",
"    (\"Log OOPpc\", X3, Y3, mu3, var3),\n",
"]\n",
"for position, (panel_title, X_obs, Y_obs, mu_k, var_k) in enumerate(panels, start=1):\n",
"    plt.subplot(1, 3, position)\n",
"    plt.plot(X_obs[:,0], Y_obs, 'x', mew=2, color=colorfill[5])\n",
"    plt.title(panel_title)\n",
"    plot_gp(X_full[:,0], mu_k[:,0], var_k[:,0], color=colorfill[4])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment