Created
July 8, 2017 00:16
-
-
Save namoshizun/a8727d34241353f271684a85e0554c5e to your computer and use it in GitHub Desktop.
simple gaussian process
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implement a very simple Gaussian process for a regression task.
Credit: http://katbailey.github.io/post/gaussian-processes-for-dummies/
""" | |
import numpy as np | |
import matplotlib.pyplot as pl | |
def prepare_data(n, fn=np.sin, low=-5, high=5):
    """Build a toy 1-D regression dataset.

    Samples ``n`` evenly spaced inputs on ``[low, high]`` and evaluates
    ``fn`` on them.

    Args:
        n: Number of training points.
        fn: Element-wise function mapping inputs to targets. It is called
            with an ndarray of shape ``(n, 1)``. Defaults to ``np.sin``.
        low: Left end of the input interval (inclusive). Defaults to -5.
        high: Right end of the input interval (inclusive). Defaults to 5.

    Returns:
        Tuple ``(x, y)``: ``x`` has shape ``(n, 1)`` and ``y`` is ``fn(x)``.
    """
    # Column vector: the kernel treats each row as one input point.
    x = np.linspace(low, high, n).reshape(-1, 1)
    # Previously the target was hard-coded to np.sin(x), silently ignoring fn.
    return x, fn(x)
def kernel(a, b, theta=0.1):
    """Isotropic squared-exponential (RBF) kernel between two point sets.

    Computes ``k(x, x') = exp(-||x - x'||^2 / (2 * theta))`` for every pair
    of rows, so ``theta`` plays the role of a squared length-scale and the
    signal variance is fixed at 1.

    Args:
        a: 2-D array of shape ``(n, d)``; each row is one input point.
        b: 2-D array of shape ``(m, d)``; each row is one input point.
        theta: Squared length-scale; larger values give smoother functions.

    Returns:
        Array of shape ``(n, m)`` with entries in ``[0, 1]``.
    """
    # ||a_i - b_j||^2 expanded as ||a_i||^2 + ||b_j||^2 - 2 a_i.b_j,
    # vectorised over all pairs via broadcasting (n,1) + (m,) -> (n, m).
    sqdist = (np.sum(a**2, axis=1).reshape(-1, 1)
              + np.sum(b**2, axis=1)
              - 2 * np.dot(a, b.T))
    # The expansion can dip slightly below zero through floating-point
    # cancellation when points (nearly) coincide, which would push the
    # kernel above 1; clamp to the true lower bound of a squared distance.
    sqdist = np.maximum(sqdist, 0.0)
    return np.exp(-0.5 * sqdist / theta)
# Test inputs: N evenly spaced points on [-5, 5] where the posterior is evaluated.
N = 50
domain = np.linspace(-5, 5, N).reshape(-1, 1)

# Prior covariance among the test inputs.
K_ss = kernel(domain, domain)

# Training data: noiseless observations of sin(x) at n points.
n = 10
x_train, y_train = prepare_data(n, fn=np.sin)

# Covariance among the training inputs. The tiny diagonal jitter only keeps
# the Cholesky factorisation numerically stable; it is not a noise model.
TRAIN_JITTER = 0.00005
K = kernel(x_train, x_train)
L = np.linalg.cholesky(K + TRAIN_JITTER * np.eye(n))

# Posterior mean: mu = K_s^T (K + jitter*I)^{-1} y. With L L^T = K + jitter*I,
# this equals (L^{-1} K_s)^T (L^{-1} y), so only triangular-factor solves are used.
K_s = kernel(x_train, domain)
Lk = np.linalg.solve(L, K_s)
mu = np.dot(Lk.T, np.linalg.solve(L, y_train)).reshape((N,))

# Posterior marginal variance: diag(K_ss) - diag(K_s^T (K + jitter*I)^{-1} K_s).
# Round-off can make it slightly negative where the posterior is nearly
# certain (e.g. on top of training points); sqrt would then return NaN, so
# clamp at zero before taking the standard deviation.
s2 = np.diag(K_ss) - np.sum(Lk**2, axis=0)
stdv = np.sqrt(np.maximum(s2, 0.0))

# Draw three samples from the posterior at the domain points:
# f = mu + L_post z with z ~ N(0, I) and L_post L_post^T the posterior
# covariance (plus jitter). L_post is kept distinct from L, the factor of
# the training covariance above.
SAMPLE_JITTER = 1e-6
L_post = np.linalg.cholesky(K_ss + SAMPLE_JITTER * np.eye(N) - np.dot(Lk.T, Lk))
f_post = mu.reshape(-1, 1) + np.dot(L_post, np.random.normal(size=(N, 3)))
# Render the GP fit on the current axes, layer by layer:
#   1. training observations as blue squares,
#   2. the posterior sample curves,
#   3. a grey band spanning mu +/- 2 standard deviations,
#   4. the posterior mean as a thick red dashed line.
# Then fix the view window, title it, and display the figure.
ax = pl.gca()
ax.plot(x_train, y_train, 'bs', ms=8)
ax.plot(domain, f_post)

band_lo = mu - 2 * stdv
band_hi = mu + 2 * stdv
ax.fill_between(domain.flat, band_lo, band_hi, color="#dddddd")

ax.plot(domain, mu, 'r--', lw=2)
ax.axis([-5, 5, -3, 3])
ax.set_title('Three samples from the GP posterior')
pl.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment