Skip to content

Instantly share code, notes, and snippets.

@tjmahr
Forked from jkeirstead/gp-predict.stan
Created June 15, 2017 13:36
Show Gist options
  • Save tjmahr/329271d16cc3ff95fb9c82be5768b4ab to your computer and use it in GitHub Desktop.
Save tjmahr/329271d16cc3ff95fb9c82be5768b4ab to your computer and use it in GitHub Desktop.
A demo of Gaussian processes using RStan
// Predict from a Gaussian process.
//
// Draws the unobserved outputs y2 at inputs x2 jointly with the observed
// outputs y1 at inputs x1, under a zero-mean multivariate normal whose
// covariance is a squared-exponential kernel plus diagonal noise:
//   Sigma[i,j] = eta_sq * exp(-rho_sq * (x[i] - x[j])^2) + (i == j) * sigma_sq
//
// All data parameters must be passed as a list to the Stan call.
// Based on original file from https://code.google.com/p/stan/source/browse/src/models/misc/gaussian-process/
//
// Written with `=` assignment and without if_else(), so it compiles both on
// older Stan releases (2.10+) and on current ones, where `<-` and if_else()
// are no longer accepted.
data {
  int<lower=1> N1;         // number of observed points
  vector[N1] x1;           // observed inputs
  vector[N1] y1;           // observed outputs
  int<lower=1> N2;         // number of prediction points
  vector[N2] x2;           // inputs at which y2 is sampled
  real<lower=0> sigma_sq;  // noise variance added to the diagonal of Sigma
  real<lower=0> eta_sq;    // kernel multiplier (value of the kernel at distance 0)
  real<lower=0> rho_sq;    // kernel decay rate applied to squared distance
}
transformed data {
  int<lower=1> N;
  vector[N1 + N2] x;
  vector[N1 + N2] mu;
  cov_matrix[N1 + N2] Sigma;

  N = N1 + N2;
  // Observed inputs first, prediction inputs second; the model block stacks
  // y1 and y2 in the same order.
  x = append_row(x1, x2);
  mu = rep_vector(0, N);
  for (i in 1:N) {
    // Fill the upper triangle and mirror it so Sigma is exactly symmetric.
    for (j in i:N) {
      Sigma[i, j] = eta_sq * exp(-rho_sq * square(x[i] - x[j]));
      Sigma[j, i] = Sigma[i, j];
    }
    // Noise term on the diagonal only (replaces if_else(i==j, sigma_sq, 0.0)).
    Sigma[i, i] = Sigma[i, i] + sigma_sq;
  }
}
parameters {
  vector[N2] y2;
}
model {
  vector[N] y;
  // Stack the fixed observations with the sampled predictions so that the
  // joint density conditions y2 on y1.
  y = append_row(y1, y2);
  y ~ multi_normal(mu, Sigma);
}
// Sample from a Gaussian process prior.
//
// Draws y at the inputs x from a zero-mean multivariate normal whose
// covariance is a squared-exponential kernel plus diagonal noise:
//   Sigma[i,j] = eta_sq * exp(-rho_sq * (x[i] - x[j])^2) + (i == j) * sigma_sq
//
// All data parameters must be passed as a list to the Stan call.
// Based on original file from https://code.google.com/p/stan/source/browse/src/models/misc/gaussian-process/
//
// Written with `=` assignment, without if_else(), and with x declared as a
// vector (not the old `real x[N]` array form) so it compiles on both older
// Stan releases (2.10+) and current ones.
data {
  int<lower=1> N;          // number of input points
  vector[N] x;             // input locations
  real<lower=0> eta_sq;    // kernel multiplier (value of the kernel at distance 0)
  real<lower=0> rho_sq;    // kernel decay rate applied to squared distance
  real<lower=0> sigma_sq;  // noise variance added to the diagonal of Sigma
}
transformed data {
  vector[N] mu;
  cov_matrix[N] Sigma;

  mu = rep_vector(0, N);
  for (i in 1:N) {
    // Fill the upper triangle and mirror it so Sigma is exactly symmetric.
    for (j in i:N) {
      Sigma[i, j] = eta_sq * exp(-rho_sq * square(x[i] - x[j]));
      Sigma[j, i] = Sigma[i, j];
    }
    // Noise term on the diagonal only (replaces if_else(i==j, sigma_sq, 0.0)).
    Sigma[i, i] = Sigma[i, i] + sigma_sq;
  }
}
parameters {
  vector[N] y;
}
model {
  y ~ multi_normal(mu, Sigma);
}
## Gaussian Process Regression with RStan
## James Keirstead
## 19 August 2013
##
## This is based on the examples in Rasmussen and Williams's Gaussian Processes book.
## See http://www.jameskeirstead.ca/blog/gaussian-process-regression-with-r/ for the long-hand version
## load the required packages
require(rstan)
require(plyr)
require(ggplot2)
## 1. Simulate a process with no data
## The very small sigma_sq value is necessary to avoid an error. Don't set it
## to zero: gp-sim.stan declares Sigma as a cov_matrix, and sigma_sq is the
## only term added to its diagonal.
x <- seq(-5, 5, 0.2)
n <- length(x)
fit <- stan(
  file = "gp-sim.stan",
  data = list(x = x, N = n, eta_sq = 1, rho_sq = 0.5, sigma_sq = 0.0001),
  iter = 200, chains = 3
)
sims <- extract(fit, permuted = TRUE)
## Rearrange the draws into long format and plot them.
## `draws` has one row per posterior draw and one column per element of x.
## Built with base R: the previous adply() + melt() version called melt(),
## which none of the packages loaded by this script provides.
draws <- sims[["y"]]
tmp <- data.frame(
  xid = rep(seq_len(ncol(draws)), times = nrow(draws)),
  group = rep(seq_len(nrow(draws)), each = ncol(draws)),
  y = as.vector(t(draws))
)
tmp$x <- x[tmp$xid]
fig2a <- ggplot(tmp, aes(x = x, y = y)) +
  geom_line(aes(group = group), colour = "#999999", alpha = 0.3) +
  theme_bw()
## 2. Simulate with a few noise-free data points.
## Again pretend the noise is almost zero, but not quite.
x1 <- c(-4, -3, -1, 0, 2)
y1 <- c(-2, 0, 1, 2, -1)
## Parameter value fixed in given example
fit <- stan(
  file = "gp-predict.stan",
  data = list(
    x1 = x1, y1 = y1, N1 = length(x1),
    x2 = x, N2 = length(x),
    eta_sq = 1, rho_sq = 0.5, sigma_sq = 0.0001
  ),
  iter = 200, chains = 3
)
sims <- extract(fit, permuted = TRUE)
## Rearrange the draws into long format and plot them.
## gp-predict.stan names its sampled parameter `y2`; `sims$y` only found it
## through `$` partial matching, so ask for it by its exact name.
## `draws` has one row per posterior draw and one column per element of x.
## Built with base R: the previous adply() + melt() version called melt(),
## which none of the packages loaded by this script provides.
draws <- sims[["y2"]]
tmp <- data.frame(
  xid = rep(seq_len(ncol(draws)), times = nrow(draws)),
  group = rep(seq_len(nrow(draws)), each = ncol(draws)),
  y = as.vector(t(draws))
)
tmp$x <- x[tmp$xid]
fig2b <- ggplot(tmp, aes(x = x, y = y)) +
  geom_line(aes(group = group), colour = "#999999", alpha = 0.3) +
  theme_bw() +
  geom_point(data = data.frame(x = x1, y = y1))
## 3. Adding more noise is easy. Just change sigma_sq
## NOTE(review): rho_sq is 1 here but 0.5 in sections 1 and 2, although the
## heading says only sigma_sq changes -- confirm this is intended.
sigma.n <- 0.1
fit <- stan(
  file = "gp-predict.stan",
  data = list(
    x1 = x1, y1 = y1, N1 = length(x1),
    x2 = x, N2 = length(x),
    eta_sq = 1, rho_sq = 1, sigma_sq = sigma.n^2
  ),
  iter = 200, chains = 3
)
sims <- extract(fit, permuted = TRUE)
## Rearrange the draws into long format and plot them.
## gp-predict.stan names its sampled parameter `y2`; `sims$y` only found it
## through `$` partial matching, so ask for it by its exact name.
## `draws` has one row per posterior draw and one column per element of x.
## Built with base R: the previous adply() + melt() version called melt(),
## which none of the packages loaded by this script provides.
draws <- sims[["y2"]]
tmp <- data.frame(
  xid = rep(seq_len(ncol(draws)), times = nrow(draws)),
  group = rep(seq_len(nrow(draws)), each = ncol(draws)),
  y = as.vector(t(draws))
)
tmp$x <- x[tmp$xid]
## Error bars span +/- 2 noise standard deviations around each observation.
fig2c <- ggplot(tmp, aes(x = x, y = y)) +
  geom_line(aes(group = group), colour = "#999999", alpha = 0.3) +
  theme_bw() +
  geom_point(data = data.frame(x = x1, y = y1)) +
  geom_errorbar(
    data = data.frame(x = x1, y = y1),
    aes(x = x, y = NULL, ymin = y - 2 * sigma.n, ymax = y + 2 * sigma.n),
    width = 0.2
  )
## Save plots for the web
## Width and height are passed to ggsave() unchanged; all three figures share
## the same size and resolution.
w <- 6
h <- 4
## The height argument previously referenced `hh`, which is never defined.
ggsave("fig2a-rstan.png", fig2a, width = w, height = h, dpi = 150)
ggsave("fig2b-rstan.png", fig2b, width = w, height = h, dpi = 150)
ggsave("fig2c-rstan.png", fig2c, width = w, height = h, dpi = 150)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment