Created
April 14, 2018 00:39
-
-
Save sadatnfs/11f5d8f63ddbabbf41afaff7d1fb39c6 to your computer and use it in GitHub Desktop.
Fit a Multiclass Gaussian Process Regression with a coregionalization kernel
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Multiclass Gaussian Process Regression with a coregionalization kernel \n", | |
"\n", | |
"### Author: Nafis Sadat\n", | |
"### Updated: April 13, 2018\n", | |
"\n", | |
"#### Purpose: Simulate three correlated series, fit a multi-output GP with a coregionalization kernel, and predict every series on a common test grid" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 99, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## Bring in packages\n", | |
"\n", | |
"import gpflow as gpflow\n", | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"import xarray as xr\n", | |
"import matplotlib.pyplot as plt\n", | |
"plt.style.use('ggplot')\n", | |
"%matplotlib inline\n", | |
"import seaborn as sns\n", | |
"np.random.seed(1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 102, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## Plot a GP fit: the predictive mean (solid) and a dashed band of +/- n_std standard deviations\n", | |
"\n", | |
"def plot_gp(x, mu, var, color='k', n_std=2.0):\n", | |
"    \"\"\"Plot a GP predictive mean with a symmetric +/- n_std * sd band.\n", | |
"\n", | |
"    x, mu, var: equal-length arrays (here the test grid and the mean/variance\n", | |
"        returned by m.predict_f, presumably shaped (N, 1)).\n", | |
"    color: matplotlib colour shared by the mean line and both band lines.\n", | |
"    n_std: band half-width in standard deviations; the default 2 gives roughly a\n", | |
"        95% interval for a Gaussian posterior.\n", | |
"    Note: predict_f returns the latent-function variance, so the band excludes\n", | |
"    observation noise (predict_y would include it).\n", | |
"    \"\"\"\n", | |
"    sd = np.sqrt(var)\n", | |
"    plt.plot(x, mu, color=color, lw=2)\n", | |
"    plt.plot(x, mu + n_std * sd, '--', color=color)\n", | |
"    plt.plot(x, mu - n_std * sd, '--', color=color)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Making our dataset:\n", | |
"\n", | |
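"We simulate inputs on the [0, 1] interval. Only the first series is observed over the whole range; the second and third are observed only on [0, 0.40] and [0, 0.65]. The goal is to use the correlation learned from the complete series to predict the partially observed ones." | |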
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"## Inputs: series 1 spans [0, 1]; series 2 and 3 only cover [0, 0.40] and [0, 0.65]\n", | |
"X1 = np.random.rand(100, 1)\n", | |
"X2 = 0.40 * np.random.rand(100, 1)\n", | |
"X3 = 0.65 * np.random.rand(100, 1)\n", | |
"\n", | |
"## Outputs: smooth signals plus heavy-tailed Student-t(3) noise at scales 0.03 / 0.1 / 0.2\n", | |
"Y1 = np.sin(5 * X1) + 0.03 * np.random.standard_t(3, X1.shape)\n", | |
"Y2 = np.sin(6 * X2 + 0.7) + 0.1 * np.random.standard_t(3, X2.shape)\n", | |
"Y3 = np.cos(5 * X3 - 0.3) + 0.2 * np.random.standard_t(3, X3.shape)\n", | |
"\n", | |
"## Scatter the three raw series\n", | |
"for x_obs, y_obs in ((X1, Y1), (X2, Y2), (X3, Y3)):\n", | |
"    plt.plot(x_obs, y_obs, 'x', mew=2)\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Coregionalized kernel: Matern-3/2 over the input value (column 0) times a Coregion\n", | |
"# kernel over the output index (column 1), which learns a 3x3 output covariance\n", | |
"# B = W W^T + diag(kappa) with W of shape (3, rank=3).\n", | |
"k1 = gpflow.kernels.Matern32(input_dim=1, active_dims=[0])\n", | |
"coreg = gpflow.kernels.Coregion(input_dim=1, output_dim=3, rank=3, active_dims=[1])\n", | |
"kern = k1 * coreg" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 75, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Stack the three series and append a column holding the output index (0, 1 or 2):\n", | |
"# the Coregion kernel reads it from X, and SwitchedLikelihood reads it from Y to pick a likelihood.\n", | |
"X_augmented = np.vstack([np.hstack((X, np.full_like(X, k))) for k, X in enumerate((X1, X2, X3))])\n", | |
"\n", | |
"Y_augmented = np.vstack([np.hstack((Y, np.full_like(Y, k))) for k, Y in enumerate((Y1, Y2, Y3))])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One likelihood per output, selected by the index column of Y_augmented.\n", | |
"# The simulated noise is Student-t(3), so StudentT matches it more closely;\n", | |
"# set USE_STUDENT_T = True to try it (Gaussian is kept as the default).\n", | |
"USE_STUDENT_T = False\n", | |
"base_likelihood = gpflow.likelihoods.StudentT if USE_STUDENT_T else gpflow.likelihoods.Gaussian\n", | |
"lik = gpflow.likelihoods.SwitchedLikelihood([base_likelihood() for _ in range(3)])\n", | |
"\n", | |
"# Build the variational GP model over all three stacked outputs\n", | |
"m = gpflow.models.VGP(X_augmented, Y_augmented, kern=kern, likelihood=lik, num_latent=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 78, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py:100: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n", | |
" \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Optimization terminated with:\n", | |
" Message: b'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'\n", | |
" Objective function value: -171.027553\n", | |
" Number of iterations: 726\n", | |
" Number of functions evaluations: 783\n" | |
] | |
} | |
], | |
"source": [ | |
"# Optimize every trainable parameter (kernel, likelihood variances, variational q_mu / q_sqrt) via SciPy\n", | |
"optimizer = gpflow.train.ScipyOptimizer()\n", | |
"optimizer.minimize(m)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 79, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>class</th>\n", | |
" <th>prior</th>\n", | |
" <th>transform</th>\n", | |
" <th>trainable</th>\n", | |
" <th>shape</th>\n", | |
" <th>fixed_shape</th>\n", | |
" <th>value</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>VGP/kern/matern32/variance</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>1.0151562158765153</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/kern/matern32/lengthscales</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>0.6359447652190311</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/kern/coregion/W</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>(none)</td>\n", | |
" <td>True</td>\n", | |
" <td>(3, 3)</td>\n", | |
" <td>True</td>\n", | |
" <td>[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/kern/coregion/kappa</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>(3,)</td>\n", | |
" <td>True</td>\n", | |
" <td>[0.6637290565759171, 0.9459817096183656, 1.479...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/likelihood/likelihood_list/0/variance</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>0.0020894776323322315</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/likelihood/likelihood_list/1/variance</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>0.018949884240303434</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/likelihood/likelihood_list/2/variance</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>0.06031213649381869</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/q_mu</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>(none)</td>\n", | |
" <td>True</td>\n", | |
" <td>(300, 1)</td>\n", | |
" <td>True</td>\n", | |
" <td>[[1.095561708152853], [-1.535152371631651], [-...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/q_sqrt</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>LoTri->vec</td>\n", | |
" <td>True</td>\n", | |
" <td>(1, 300, 300)</td>\n", | |
" <td>True</td>\n", | |
" <td>[[[0.019321339978307588, 0.0, 0.0, 0.0, 0.0, 0...</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" class prior transform \\\n", | |
"VGP/kern/matern32/variance Parameter None +ve \n", | |
"VGP/kern/matern32/lengthscales Parameter None +ve \n", | |
"VGP/kern/coregion/W Parameter None (none) \n", | |
"VGP/kern/coregion/kappa Parameter None +ve \n", | |
"VGP/likelihood/likelihood_list/0/variance Parameter None +ve \n", | |
"VGP/likelihood/likelihood_list/1/variance Parameter None +ve \n", | |
"VGP/likelihood/likelihood_list/2/variance Parameter None +ve \n", | |
"VGP/q_mu Parameter None (none) \n", | |
"VGP/q_sqrt Parameter None LoTri->vec \n", | |
"\n", | |
" trainable shape \\\n", | |
"VGP/kern/matern32/variance True () \n", | |
"VGP/kern/matern32/lengthscales True () \n", | |
"VGP/kern/coregion/W True (3, 3) \n", | |
"VGP/kern/coregion/kappa True (3,) \n", | |
"VGP/likelihood/likelihood_list/0/variance True () \n", | |
"VGP/likelihood/likelihood_list/1/variance True () \n", | |
"VGP/likelihood/likelihood_list/2/variance True () \n", | |
"VGP/q_mu True (300, 1) \n", | |
"VGP/q_sqrt True (1, 300, 300) \n", | |
"\n", | |
" fixed_shape \\\n", | |
"VGP/kern/matern32/variance True \n", | |
"VGP/kern/matern32/lengthscales True \n", | |
"VGP/kern/coregion/W True \n", | |
"VGP/kern/coregion/kappa True \n", | |
"VGP/likelihood/likelihood_list/0/variance True \n", | |
"VGP/likelihood/likelihood_list/1/variance True \n", | |
"VGP/likelihood/likelihood_list/2/variance True \n", | |
"VGP/q_mu True \n", | |
"VGP/q_sqrt True \n", | |
"\n", | |
" value \n", | |
"VGP/kern/matern32/variance 1.0151562158765153 \n", | |
"VGP/kern/matern32/lengthscales 0.6359447652190311 \n", | |
"VGP/kern/coregion/W [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ... \n", | |
"VGP/kern/coregion/kappa [0.6637290565759171, 0.9459817096183656, 1.479... \n", | |
"VGP/likelihood/likelihood_list/0/variance 0.0020894776323322315 \n", | |
"VGP/likelihood/likelihood_list/1/variance 0.018949884240303434 \n", | |
"VGP/likelihood/likelihood_list/2/variance 0.06031213649381869 \n", | |
"VGP/q_mu [[1.095561708152853], [-1.535152371631651], [-... \n", | |
"VGP/q_sqrt [[[0.019321339978307588, 0.0, 0.0, 0.0, 0.0, 0... " | |
] | |
}, | |
"execution_count": 79, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"## Inspect the parameter values after the first optimization\n", | |
"params_first_fit = m.as_pandas_table()\n", | |
"params_first_fit" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## Prediction grid: 100 evenly spaced inputs over [0, 1], as a (100, 1) column\n", | |
"xtest = np.linspace(0, 1, 100)[:, None]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 81, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## Predict the latent function of each output on the grid (index column set to 0, 1, 2)\n", | |
"(mu1, var1), (mu2, var2), (mu3, var3) = [\n", | |
"    m.predict_f(np.hstack((xtest, np.full_like(xtest, k)))) for k in range(3)\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"## Data (crosses) and posterior fits after the first optimization\n", | |
"for x_obs, y_obs, mu_k, var_k in ((X1, Y1, mu1, var1), (X2, Y2, mu2, var2), (X3, Y3, mu3, var3)):\n", | |
"    # reuse each scatter's colour for its fitted mean and band\n", | |
"    line, = plt.plot(x_obs, y_obs, 'x', mew=2)\n", | |
"    plot_gp(xtest, mu_k, var_k, line.get_color())\n", | |
"plt.xlabel('x')\n", | |
"plt.ylabel('y')\n", | |
"plt.title('Coregionalized GP fit: first optimization');" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
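"### The first fit does a poor job of sharing information between the three outputs where Y2 and Y3 are unobserved (x > 0.40 and x > 0.65, respectively). The coregion matrix W is still all zeros after optimization (see the table above). W = 0 is a saddle point of the objective: the gradient with respect to W vanishes there, so the optimizer never moves away from it. The output covariance therefore reduces to diag(kappa), and the outputs are effectively modelled independently. " | |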
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 93, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py:100: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n", | |
" \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Optimization terminated with:\n", | |
" Message: b'STOP: TOTAL NO. of ITERATIONS EXCEEDS LIMIT'\n", | |
" Objective function value: -180.778847\n", | |
" Number of iterations: 1001\n", | |
" Number of functions evaluations: 1089\n" | |
] | |
} | |
], | |
"source": [ | |
"## The first fit left the coregion W matrix at all zeros (a saddle point of the objective).\n", | |
"## Re-initialize W with random entries (unit diagonal) and re-optimize from there.\n", | |
"rand_mat = np.random.randn(3, 3)\n", | |
"np.fill_diagonal(rand_mat, 1.0)\n", | |
"\n", | |
"m.kern.coregion.W = rand_mat\n", | |
"# The default 1000-iteration cap can stop this run before convergence, so allow more iterations\n", | |
"gpflow.train.ScipyOptimizer().minimize(m, maxiter=5000)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 96, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>class</th>\n", | |
" <th>prior</th>\n", | |
" <th>transform</th>\n", | |
" <th>trainable</th>\n", | |
" <th>shape</th>\n", | |
" <th>fixed_shape</th>\n", | |
" <th>value</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>VGP/kern/matern32/variance</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>0.027341925687519704</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/kern/matern32/lengthscales</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>0.734572182397693</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/kern/coregion/W</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>(none)</td>\n", | |
" <td>True</td>\n", | |
" <td>(3, 3)</td>\n", | |
" <td>True</td>\n", | |
" <td>[[5.33029410916914, 1.3136740802790725, -1.569...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/kern/coregion/kappa</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>(3,)</td>\n", | |
" <td>True</td>\n", | |
" <td>[0.0709326447046722, 0.0036430587971050567, 0....</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/likelihood/likelihood_list/0/variance</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>0.002096954340679575</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/likelihood/likelihood_list/1/variance</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>0.018810843757694297</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/likelihood/likelihood_list/2/variance</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>+ve</td>\n", | |
" <td>True</td>\n", | |
" <td>()</td>\n", | |
" <td>True</td>\n", | |
" <td>0.05996604408304894</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/q_mu</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>(none)</td>\n", | |
" <td>True</td>\n", | |
" <td>(300, 1)</td>\n", | |
" <td>True</td>\n", | |
" <td>[[0.953441843214516], [-1.560486883391416], [-...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>VGP/q_sqrt</th>\n", | |
" <td>Parameter</td>\n", | |
" <td>None</td>\n", | |
" <td>LoTri->vec</td>\n", | |
" <td>True</td>\n", | |
" <td>(1, 300, 300)</td>\n", | |
" <td>True</td>\n", | |
" <td>[[[0.0162377753439612, 0.0, 0.0, 0.0, 0.0, 0.0...</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" class prior transform \\\n", | |
"VGP/kern/matern32/variance Parameter None +ve \n", | |
"VGP/kern/matern32/lengthscales Parameter None +ve \n", | |
"VGP/kern/coregion/W Parameter None (none) \n", | |
"VGP/kern/coregion/kappa Parameter None +ve \n", | |
"VGP/likelihood/likelihood_list/0/variance Parameter None +ve \n", | |
"VGP/likelihood/likelihood_list/1/variance Parameter None +ve \n", | |
"VGP/likelihood/likelihood_list/2/variance Parameter None +ve \n", | |
"VGP/q_mu Parameter None (none) \n", | |
"VGP/q_sqrt Parameter None LoTri->vec \n", | |
"\n", | |
" trainable shape \\\n", | |
"VGP/kern/matern32/variance True () \n", | |
"VGP/kern/matern32/lengthscales True () \n", | |
"VGP/kern/coregion/W True (3, 3) \n", | |
"VGP/kern/coregion/kappa True (3,) \n", | |
"VGP/likelihood/likelihood_list/0/variance True () \n", | |
"VGP/likelihood/likelihood_list/1/variance True () \n", | |
"VGP/likelihood/likelihood_list/2/variance True () \n", | |
"VGP/q_mu True (300, 1) \n", | |
"VGP/q_sqrt True (1, 300, 300) \n", | |
"\n", | |
" fixed_shape \\\n", | |
"VGP/kern/matern32/variance True \n", | |
"VGP/kern/matern32/lengthscales True \n", | |
"VGP/kern/coregion/W True \n", | |
"VGP/kern/coregion/kappa True \n", | |
"VGP/likelihood/likelihood_list/0/variance True \n", | |
"VGP/likelihood/likelihood_list/1/variance True \n", | |
"VGP/likelihood/likelihood_list/2/variance True \n", | |
"VGP/q_mu True \n", | |
"VGP/q_sqrt True \n", | |
"\n", | |
" value \n", | |
"VGP/kern/matern32/variance 0.027341925687519704 \n", | |
"VGP/kern/matern32/lengthscales 0.734572182397693 \n", | |
"VGP/kern/coregion/W [[5.33029410916914, 1.3136740802790725, -1.569... \n", | |
"VGP/kern/coregion/kappa [0.0709326447046722, 0.0036430587971050567, 0.... \n", | |
"VGP/likelihood/likelihood_list/0/variance 0.002096954340679575 \n", | |
"VGP/likelihood/likelihood_list/1/variance 0.018810843757694297 \n", | |
"VGP/likelihood/likelihood_list/2/variance 0.05996604408304894 \n", | |
"VGP/q_mu [[0.953441843214516], [-1.560486883391416], [-... \n", | |
"VGP/q_sqrt [[[0.0162377753439612, 0.0, 0.0, 0.0, 0.0, 0.0... " | |
] | |
}, | |
"execution_count": 96, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"## Inspect the parameter values after re-initializing W and re-optimizing\n", | |
"params_after_reinit = m.as_pandas_table()\n", | |
"params_after_reinit" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 97, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## Re-predict each output on the grid with the re-optimized parameters (no refitting here)\n", | |
"\n", | |
"(mu1, var1), (mu2, var2), (mu3, var3) = [\n", | |
"    m.predict_f(np.hstack((xtest, np.full_like(xtest, k)))) for k in range(3)\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 101, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"## Data (crosses) and posterior fits after re-initializing W and re-optimizing\n", | |
"for x_obs, y_obs, mu_k, var_k in ((X1, Y1, mu1, var1), (X2, Y2, mu2, var2), (X3, Y3, mu3, var3)):\n", | |
"    # reuse each scatter's colour for its fitted mean and band\n", | |
"    line, = plt.plot(x_obs, y_obs, 'x', mew=2)\n", | |
"    plot_gp(xtest, mu_k, var_k, line.get_color())\n", | |
"plt.xlabel('x')\n", | |
"plt.ylabel('y')\n", | |
"plt.title('Coregionalized GP fit: after re-initializing W');" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment