GP Regression example
gp_regression.stan
functions{
  // GP: computes noiseless Gaussian Process
  vector GP(real volatility, real amplitude, vector normal01, int n_x, real[] x ) {
    matrix[n_x,n_x] cov_mat ;
    real amplitude_sq_plus_jitter ;
    amplitude_sq_plus_jitter = amplitude^2 + 1e-6 ;
    cov_mat = cov_exp_quad(x, amplitude, 1/volatility) ;
    for(i in 1:n_x){
      cov_mat[i,i] = amplitude_sq_plus_jitter ;
    }
    return(cholesky_decompose(cov_mat) * normal01 ) ;
  }
}
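// Note on GP(): this is the non-centered ("whitened") parameterization.
// If normal01 ~ normal(0,1) and L = cholesky_decompose(K), then L*normal01
// is multivariate normal(0, K); the 1e-6 added to the diagonal is jitter
// that keeps K numerically positive-definite for the decomposition.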
data {
  // n_y: number of observations in y
  int n_y ;
  // y: vector of observations for y
  //  should be scaled to mean=0,sd=1
  vector[n_y] y ;
  // n_x: number of unique x values
  int n_x ;
  // x: unique values of x
  //  should be scaled to min=0,max=1
  real x[n_x] ;
  // x_index: vector indicating which x is associated with each y
  int x_index[n_y] ;
  // n_z: number of columns in predictor matrix z
  int n_z ;
  // rows_z_unique: number of unique rows in the predictor matrix (rows of z_unique)
  int rows_z_unique ;
  // z_unique: predictor matrix (each column gets its own GP)
  matrix[rows_z_unique,n_z] z_unique ;
  // z_by_f_index: index into the flattened f*tz matrix indicating which
  //  x-by-z_unique-row combination is associated with each y
  int z_by_f_index[n_y] ;
}
transformed data{
  matrix[n_z,rows_z_unique] tz = transpose(z_unique);
}
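// Note: tz is precomputed so that f*tz in the model block is an
// n_x by rows_z_unique matrix whose column j holds the predicted mean
// function for the j-th unique row of the predictor matrix.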
parameters {
  // noise: measurement noise
  real<lower=0> noise ;
  // volatility_helper: helper for cauchy-distributed volatility (see transformed parameters)
  vector<lower=0,upper=pi()/2>[n_z] volatility_helper ;
  // amplitude: amplitude of GPs
  vector<lower=0>[n_z] amplitude ;
  // f_normal01: helper variable for GPs (see transformed parameters)
  matrix[n_x,n_z] f_normal01 ;
}
transformed parameters{
  // volatility: volatility of GPs (a.k.a. inverse-lengthscale)
  vector[n_z] volatility ;
  // f: GPs
  matrix[n_x,n_z] f ;
  //next line implies volatility ~ cauchy(0,10)
  volatility = 10*tan(volatility_helper) ;
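  // (a uniform(0,pi/2) variable pushed through 10*tan() is half-Cauchy(0,10);
  //  sampling the bounded helper behaves better than sampling the
  //  heavy-tailed volatility directly)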
// loop over predictors, computing GPs for each predictor | |
for(zi in 1:n_z){ | |
f[,zi] = GP( | |
volatility[zi] | |
, amplitude[zi] | |
, f_normal01[,zi] | |
, n_x , x | |
) ; | |
} | |
} | |
model {
  // noise prior
  noise ~ weibull(2,1) ; //peaked at .8ish
  // amplitude prior
  amplitude ~ weibull(2,1) ; //peaked at .8ish
  // normal(0,1) priors on GP helpers
  to_vector(f_normal01) ~ normal(0,1);
  // likelihood (vectorized over observations)
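  // f*tz is n_x by rows_z_unique; to_vector() stacks its columns, and
  // z_by_f_index picks out, for each observation, the entry matching its
  // x value and its row of z_unique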
  y ~ normal(
    to_vector(f*tz)[z_by_f_index]
    , noise
  );
}
R script: simulate data, fit the model with rstan/ezStan, and visualize the posterior
# load packages
library(tidyverse)
library(rstan)
rstan_options(auto_write = TRUE)
# load ezStan (if you don't have it, install via: devtools::install_github('mike-lawrence/ezStan') )
library(ezStan)
# ezStan has some nice functions for starting & watching parallel chains,
#  as well as a nicer summary table of the posterior samples
# Make some fake data ----
# n_x: number of unique samples on x-axis
n_x = 100
# n_reps: number of repeated observations per-x per-condition
n_reps = 10
# prep a tibble with each combination of x, rep & condition
dat = as_tibble(expand.grid(
  x = seq(-10,10,length.out=n_x)
  , rep = 1:n_reps
  , condition = c(-.5,.5)
))
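# sanity check (added): expect one row per x-by-rep-by-condition combination
stopifnot(nrow(dat) == n_x*n_reps*2) # 100*10*2 = 2000 rows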
# set random seed for reproducibility
set.seed(1)
# add some columns, eventually leading to observed data
dat %>%
  dplyr::mutate(
    intercept = sin(x)*dnorm(x,5,8) #arbitrary wiggly curve
    , effect = dnorm(x-5,2)/5 #ditto
    , true = intercept + effect*condition #combined
    , obs = scale(true)[,1] + rnorm(n(),0,1) #true plus noise
  ) ->
  dat
# show the intercept function
dat %>%
  ggplot(
    mapping = aes(
      x = x
      , y = intercept
    )
  )+
  geom_line()
# show the effect function
dat %>%
  ggplot(
    mapping = aes(
      x = x
      , y = effect
    )
  )+
  geom_line()
# show the condition functions
dat %>%
  ggplot(
    mapping = aes(
      x = x
      , y = true
      , colour = factor(condition)
      , group = factor(condition)
    )
  )+
  geom_line()
# show the noisy observations
dat %>%
  ggplot(
    mapping = aes(
      x = x
      , y = obs
      , group = rep
    )
  )+
  geom_line(alpha = .1)+
  facet_grid(
    condition ~ .
  )
# show the noisy observations collapsed to means
dat %>%
  dplyr::group_by(
    x
    , condition
  ) %>%
  dplyr::summarise(
    m = mean(obs)
  ) %>%
  ggplot(
    mapping = aes(
      x = x
      , y = m
      , colour = factor(condition)
      , group = factor(condition)
    )
  )+
  geom_line()
# get rid of columns we wouldn't actually have for real data
dat %>%
  dplyr::select(
    -true
    , -intercept
    , -effect
  ) ->
  dat
# Prep the data for stan ----
# get the sorted unique values of x
x = sort(unique(dat$x))
# for each value in dat$x, get its index in x
x_index = match(dat$x,x)
# compute the model matrix
z = model.matrix(
  data = dat
  , object = ~ condition
)
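# (with a single numeric condition column, z has two columns -- "(Intercept)"
#  and "condition" -- so the model fits one GP for the grand mean and one for
#  the condition effect)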
# compute the unique entries in the model matrix
temp = as.data.frame(z)
temp = tidyr::unite_(data = temp, col = 'combined', from = names(temp))
temp_unique = unique(temp)
z_unique = z[row.names(z)%in%row.names(temp_unique),]
# for each row in z, get its index in z_unique
z_unique_index = match(temp$combined,temp_unique$combined)
# combine the two index objects to get the index into the flattened f*tz matrix
z_by_f_index = x_index + (z_unique_index-1)*length(x)
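# (in the Stan model, to_vector(f*tz) stacks the n_x-by-rows_z_unique matrix
#  column by column, so the entry for x_index i and z_unique row j sits at
#  position i + (j-1)*n_x -- exactly what the line above computes)
stopifnot(max(z_by_f_index) <= length(x)*nrow(z_unique)) # added sanity check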
# create the data list for stan
data_for_stan = list(
  n_y = nrow(dat)
  , y = scale(dat$obs)[,1] #scaled to mean=0,sd=1
  , n_x = length(x)
  , x = (x-min(x))/(max(x)-min(x)) #scaled to min=0,max=1
  , x_index = x_index
  , n_z = ncol(z)
  , rows_z_unique = nrow(z_unique)
  , z_unique = z_unique
  , z_by_f_index = z_by_f_index
)
# model ----
# compile
gp_regression_mod = rstan::stan_model('gp_regression.stan')
# start the parallel chains
ezStan::start_stan(
  mod = gp_regression_mod
  , data = data_for_stan
  , include = FALSE
  , pars = c('f_normal01','volatility_helper')
  , control = list(
    adapt_delta = .99 #GPs tend to need higher-than-default adapt_delta
  )
)
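# alternative (added sketch): if ezStan isn't available, the same fit can be
#  run with plain rstan (4 chains/cores assumed); uncomment to use this
#  instead of the ezStan calls above and below
# post = rstan::sampling(
#   object = gp_regression_mod
#   , data = data_for_stan
#   , pars = c('f_normal01','volatility_helper')
#   , include = FALSE
#   , chains = 4
#   , cores = 4
#   , control = list(adapt_delta = .99)
# )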
# watch the chains' progress
ezStan::watch_stan()
# collect results
post = ezStan::collect_stan()
# kill just in case
ezStan::kill_stan()
# delete temp folder
ezStan::clean_stan()
# how long did it take?
sort(rowSums(get_elapsed_time(post)/60))
# check noise & GP parameters
ezStan::stan_summary(
  from_stan = post
  , par = c('noise','volatility','amplitude')
)
# check the rhats for the latent functions
fstats = ezStan::stan_summary(
  from_stan = post
  , par = 'f'
  , return_array = TRUE
)
summary(fstats[,ncol(fstats)]) #rhats
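# (added) equivalent check via rstan's stanfit summary method:
# max(summary(post, pars = 'f')$summary[,'Rhat'])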
# visualize latent functions
f = rstan::extract(
  post
  , pars = 'f'
)[[1]]
f2 = tibble::as_tibble(data.frame(matrix(
  f
  , byrow = F
  , nrow = dim(f)[1]
  , ncol = dim(f)[2]*dim(f)[3]
)))
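# (f is a draws x n_x x n_z array; filling the matrix column-wise puts the
#  n_x columns for the first GP first, then the n_x columns for the second,
#  which is what the parameter/x bookkeeping below relies on)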
f2$sample = 1:nrow(f2)
f2 %>%
  tidyr::gather(
    key = 'key'
    , value = 'value'
    , -sample
  ) %>%
  dplyr::mutate(
    key = as.numeric(gsub('X','',key))
    , parameter = rep(
      1:dim(f)[3]
      , each = dim(f)[1]*dim(f)[2]
    )
    , x = rep(x,each=dim(f)[1],times=dim(f)[3])
  ) %>%
  dplyr::select(
    -key
  ) ->
  fdat
fdat %>%
  dplyr::group_by(
    x
    , parameter
  ) %>%
  dplyr::summarise(
    med = median(value)
    , lo95 = quantile(value,.025)
    , hi95 = quantile(value,.975)
    , lo50 = quantile(value,.25)
    , hi50 = quantile(value,.75)
  ) %>%
  ggplot()+
  geom_hline(yintercept=0)+
  geom_ribbon(
    mapping = aes(
      x = x
      , ymin = lo95
      , ymax = hi95
    )
    , alpha = .5
  )+
  geom_ribbon(
    mapping = aes(
      x = x
      , ymin = lo50
      , ymax = hi50
    )
    , alpha = .5
  )+
  geom_line(
    mapping = aes(
      x = x
      , y = med
    )
    , alpha = .5
  )+
  facet_grid(
    parameter ~ .
    , scales = 'free_y'
  )
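# (added sketch) the ribbons above are for the latent GPs (intercept & effect);
#  per-condition fitted curves can be reconstructed from the same draws, since
#  each draw's fit for z_unique row j is f[draw,,] %*% z_unique[j,]
cond_pred = apply(f, 1, function(f_draw) as.vector(f_draw %*% t(z_unique)))
# cond_pred is (n_x*rows_z_unique) x draws; rows 1:n_x correspond to the first
#  row of z_unique, rows (n_x+1):(2*n_x) to the second, on the scaled-y scale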