Last active
September 15, 2018 18:52
-
-
Save jotterbach/2ba82eb73b8071bdb2968da2a423ca5f to your computer and use it in GitHub Desktop.
Snippet demonstrating how to condition a Gaussian process on observed data and how to plot the resulting posterior.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import seaborn as sns | |
| import numpy as np | |
| import numpy.random as rd | |
| import matplotlib.pyplot as plt | |
| from scipy.spatial import distance as dist | |
| import scipy.linalg as la | |
| import itertools as it | |
| %matplotlib inline | |
# Seaborn "whitegrid" theme, overriding the axes background to pure white.
sns.set_style(style='whitegrid', rc={"axes.facecolor": "1.0"})
def kernel(x0, x1, scale=1.0):
    """Squared-exponential (RBF) covariance between two points.

    Computes ``exp(-0.5 * ||x0 - x1||**2 / scale)``. ``scale`` divides the
    *squared* distance, so it plays the role of the squared length-scale
    ``l**2`` rather than ``l`` itself.

    Parameters
    ----------
    x0, x1 : scalar or array_like
        The two points to compare. Scalars (as used elsewhere in this file)
        and equal-length vectors are both accepted.
    scale : float
        Divisor of the squared distance. Larger values make the covariance
        decay more slowly with distance.

    Returns
    -------
    numpy.float64
        The covariance; exactly 1.0 when ``x0 == x1``.
    """
    # Compute the squared distance directly with NumPy. Newer SciPy versions
    # require 1-D inputs to ``scipy.spatial.distance.euclidean``, so scalar
    # points may be rejected there. Squaring a square root also only adds
    # rounding error.
    diff = np.subtract(x0, x1, dtype=float)
    squared_distance = np.sum(diff * diff)
    return np.exp(-.5 * squared_distance / scale)
def calculate_kernel_matrix(x, y, scale=1.0):
    """Build the Gram (covariance) matrix between two collections of points.

    Entry ``[i, j]`` holds ``kernel(x[i], y[j], scale=scale)``.

    Parameters
    ----------
    x, y : sequence
        Points indexing the rows and the columns, respectively.
    scale : float
        Forwarded unchanged to ``kernel``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(x), len(y))``.
    """
    gram = np.zeros(shape=(len(x), len(y)))
    # Row-major traversal: every y for the first x, then every y for the next x.
    for (row, a), (col, b) in it.product(enumerate(x), enumerate(y)):
        gram[row, col] = kernel(a, b, scale=scale)
    return gram
def get_updated_mean_variance(k, ks, kss, obs_val, noise=0.0):
    """Condition a zero-mean Gaussian process on observed values.

    Parameters
    ----------
    k : array_like, shape (n_obs, n_obs)
        Kernel matrix among the observation inputs.
    ks : array_like, shape (n_eval, n_obs)
        Cross-kernel between evaluation points (rows) and observation inputs.
    kss : array_like, shape (n_eval, n_eval)
        Prior kernel matrix among the evaluation points.
    obs_val : array_like, shape (n_obs,)
        Observed function values.
    noise : float, optional
        Variance added to the diagonal of ``k`` (observation noise / jitter).
        The default ``0.0`` gives noise-free conditioning.

    Returns
    -------
    new_mean : numpy.ndarray, shape (n_eval,)
        Posterior mean ``ks @ k^-1 @ obs_val``.
    new_k : numpy.ndarray, shape (n_eval, n_eval)
        Posterior covariance ``kss - ks @ k^-1 @ ks.T``, made exactly symmetric.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``k + noise * I`` is not positive definite (e.g. duplicate
        observation inputs with ``noise=0``).
    """
    k = np.asarray(k, dtype=float)
    ks = np.asarray(ks, dtype=float)
    kss = np.asarray(kss, dtype=float)
    obs_val = np.asarray(obs_val, dtype=float)
    if obs_val.size == 0:
        # Nothing to condition on: the posterior is the (zero-mean) prior.
        return np.zeros(kss.shape[0]), kss.copy()
    if noise:
        k = k + noise * np.eye(k.shape[0])
    # Solve with a Cholesky factorisation instead of forming k^-1 explicitly.
    # This is cheaper and numerically more stable for the ill-conditioned
    # kernel matrices that nearby inputs produce.
    factor = la.cho_factor(k, lower=True)
    new_mean = ks.dot(la.cho_solve(factor, obs_val))
    new_k = kss - ks.dot(la.cho_solve(factor, ks.T))
    # Round-off leaves new_k slightly asymmetric, which can make samplers such
    # as numpy's multivariate_normal complain; restore exact symmetry.
    new_k = 0.5 * (new_k + new_k.T)
    return new_mean, new_k
def plot_gaussian_process(observations, evaluation_pts, kernel_scale=1.0, n_samples=3):
    """Plot random draws, the mean and a 2-sigma band of a Gaussian process.

    The process is conditioned on ``observations`` when any are given;
    otherwise the zero-mean prior is plotted.

    Parameters
    ----------
    observations : sequence of (x, f(x)) pairs
        Points to condition on; may be empty.
    evaluation_pts : 1-D array_like
        Inputs at which the process is evaluated and drawn.
    kernel_scale : float
        Forwarded as ``scale`` to the kernel-matrix computation.
    n_samples : int, optional
        Number of random process draws to plot (default 3).
    """
    eval_kernel = calculate_kernel_matrix(evaluation_pts, evaluation_pts, scale=kernel_scale)
    if len(observations) > 0:
        x = [p[0] for p in observations]
        fx = [p[1] for p in observations]
        obs_kernel = calculate_kernel_matrix(x, x, scale=kernel_scale)
        eval_obs_kernel = calculate_kernel_matrix(evaluation_pts, x, scale=kernel_scale)
        mu, sigma = get_updated_mean_variance(obs_kernel, eval_obs_kernel, eval_kernel, fx)
    else:
        # No data: plot the zero-mean prior.
        mu = np.zeros(len(evaluation_pts))
        sigma = eval_kernel
    # At observed inputs the posterior variance is ~0, and round-off can push
    # it slightly negative. Clip so np.sqrt does not yield NaNs, which would
    # leave gaps in the shaded band.
    std = np.sqrt(np.clip(np.diag(sigma), 0.0, None))

    plt.figure(figsize=(12, 8))
    for cnt in range(n_samples):
        plt.plot(evaluation_pts, rd.multivariate_normal(mu, sigma),
                 label="process sample {}".format(cnt + 1))
    plt.fill_between(evaluation_pts,
                     mu + 2 * std,
                     mu - 2 * std,
                     alpha=0.3,
                     label=r"$2\,\sigma$")
    plt.plot(evaluation_pts, mu, 'k--', label='mean prediction')
    if len(observations) > 0:
        plt.scatter(x, fx, color='k', s=300, marker='+', label="observations")
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.xlabel(r"input $\theta$", fontsize=22)
    plt.ylabel(r"output $f(\theta)$", fontsize=22)
    plt.legend(loc='best', fontsize=16)
    # Compute the layout last so the enlarged tick and axis labels are
    # accounted for (calling it earlier ignores them and can clip them).
    plt.tight_layout()
# Observed (x, f(x)) pairs that the process is conditioned on.
obsv = [(-4, 1.5), (-3.8, 1), (-1.3, -.25), (2.6, 1.3), (3.5, .15)]

# Draw the conditioned process on a 100-point grid spanning [-5, 5].
plot_gaussian_process(
    observations=obsv,
    evaluation_pts=np.linspace(start=-5, stop=5, num=100),
    kernel_scale=1,
)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The code produces a plot like: