## Last active
## October 16, 2023 21:21
## -
## -
## Save gavinsimpson/99741b2bfc7a0bbf300e55cd7110eee3 to your computer and use it in GitHub Desktop.
## Extrapolating with penalised b-splines and mgcv
## This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
## Learn more about bidirectional Unicode characters
## Can we do better extrapolation than the TPRS example used by Gabriel
## Riutort-Mayol et al in their recent paper on low rank approximations to
## Gaussian Processes? https://arxiv.org/abs/2004.11408

## Packages
library("ggplot2")
library("tibble")
library("tidyr")
library("dplyr")
library("mgcv")
## colorblindr supplies the Okabe-Ito colour/fill scales used in the plots;
## it is not on CRAN, install with:
## remotes::install_github("clauswilke/colorblindr")
library("colorblindr")
## relayer supplies rename_geom_aes(), used to give the fitted lines their own
## colour scale; it is not on CRAN, install with:
## remotes::install_github("clauswilke/relayer")
library("relayer")
## Univariate example -----------------------------------------------------------

## Load realisations from the true function. The .rData file is expected to
## create the object `f_true` used below.
load(url("https://github.com/gabriuma/basis_functions_approach_to_GP/raw/master/Paper/Case-study_1D-Simulated-data/r-code/f_true.rData"))

## Simulate noisy data from the function
seed <- 1234
set.seed(seed)

## Reference data: the true function on a regular grid over [-1, 1] plus
## Gaussian noise with standard deviation 0.2
gp_data <- tibble(truth = unname(f_true), x = seq(-1, 1, by = 0.002)) %>%
  mutate(y = truth + rnorm(length(truth), 0, 0.2))

## Sample 250 rows for further analysis. Observations in both tails
## (|x| > 0.8) and in three interior gaps are labelled "test"; these are the
## regions we extrapolate and interpolate over. Everything else is "train".
set.seed(seed)
r_samp <- sample_n(gp_data, size = 250) %>%
  arrange(x) %>%
  mutate(data_set = case_when(x < -0.8 ~ "test",
                              x > 0.8 ~ "test",
                              x > -0.45 & x < -0.36 ~ "test",
                              x > -0.05 & x < 0.05 ~ "test",
                              x > 0.45 & x < 0.6 ~ "test",
                              TRUE ~ "train"))

## Plot the random sample and indicate the samples we will extrapolate
## and interpolate values for; the grey line is the true function
ggplot(r_samp, aes(x = x, y = y, colour = data_set)) +
  geom_line(aes(y = truth, colour = NULL), show.legend = FALSE, alpha = 0.5) +
  geom_point() +
  scale_colour_brewer(palette = "Set1")
## Set the knots for B-spline. Four knots are supplied: per the mgcv B-spline
## documentation (?smooth.construct.bs.smooth.spec) the inner pair bounds the
## region where the interior knots are placed and the outer pair gives the
## interval over which the penalty is evaluated, so the penalty also acts over
## the range we extrapolate into.
knots2 <- list(x = c(-2, -0.9, 0.9, 2))

## Training subset, shared by all the fits below
train <- filter(r_samp, data_set == "train")

## Fit a few different spline bases
## B-spline: m = c(3, 2) is a cubic basis with a second-derivative penalty
m2_bs <- gam(y ~ s(x, k = 50, bs = "bs", m = c(3, 2)),
             data = train, method = "REML", knots = knots2)
## Thin plate regression spline with default settings
m2_tprs <- gam(y ~ s(x, k = 50, bs = "tp"),
               data = train, method = "REML")
## Gaussian process smooth; m selects the correlation function and its range
## parameter (see ?gp.smooth for the coding of m)
m2_gp <- gam(y ~ s(x, k = 50, bs = "gp", m = c(3, 0.15)),
             data = train, method = "REML")

## Data to predict at: wider than the observed [-1, 1] to show extrapolation
new_data <- tibble(x = seq(-1.5, 1.5, by = 0.002))

## Predict from the three models; suffix the columns with the basis name so
## they can be split apart again after binding
p2_bs <- as_tibble(predict(m2_bs, new_data, se.fit = TRUE)) %>%
  rename(fit_bs = fit, se_bs = se.fit)
p2_tprs <- as_tibble(predict(m2_tprs, new_data, se.fit = TRUE)) %>%
  rename(fit_tprs = fit, se_tprs = se.fit)
p2_gp <- as_tibble(predict(m2_gp, new_data, se.fit = TRUE)) %>%
  rename(fit_gp = fit, se_gp = se.fit)

## Put all predictions together: one row per x and basis, with columns
## `fit` and `se`, plus an approximate interval of fit +/- 2 standard errors
new_data_bases <- bind_cols(new_data, p2_tprs, p2_bs, p2_gp) %>%
  pivot_longer(fit_tprs:se_gp, names_sep = "_",
               names_to = c("variable", "spline")) %>%
  pivot_wider(names_from = variable, values_from = value) %>%
  mutate(upr_ci = fit + (2 * se), lwr_ci = fit - (2 * se))

## The fitted lines are mapped to a temporary `colour2` aesthetic, which
## rename_geom_aes() turns into the layer's colour; this lets the points and
## the lines each have their own colour scale and legend.
ggplot(mapping = aes(x = x, y = y)) +
  geom_ribbon(data = new_data_bases,
              mapping = aes(ymin = lwr_ci, ymax = upr_ci, x = x, fill = spline),
              inherit.aes = FALSE, alpha = 0.1) +
  geom_point(data = r_samp, aes(colour = data_set)) +
  geom_line(data = new_data_bases, aes(y = fit, x = x, colour2 = spline),
            size = 1) %>%
    rename_geom_aes(new_aes = c("colour" = "colour2")) +
  scale_colour_brewer(palette = "Set1", aesthetics = "colour",
                      name = "Data set") +
  scale_colour_OkabeIto(aesthetics = "colour2", name = "Basis") +
  scale_fill_OkabeIto(name = "Basis") +
  coord_cartesian(ylim = c(-2, 2)) +
  labs(title = "Extrapolating with splines",
       subtitle = "How behaviour varies with different basis types")
## Compare penalties: cubic B-splines that differ only in the order of the
## derivative penalty (second element of m: 2, 1 or 0)

## Training subset (defined here so this section runs on its own)
train <- filter(r_samp, data_set == "train")

m2_bs_2 <- gam(y ~ s(x, k = 50, bs = "bs", m = c(3, 2)),
               data = train, method = "REML", knots = knots2)
m2_bs_1 <- gam(y ~ s(x, k = 50, bs = "bs", m = c(3, 1)),
               data = train, method = "REML", knots = knots2)
m2_bs_0 <- gam(y ~ s(x, k = 50, bs = "bs", m = c(3, 0)),
               data = train, method = "REML", knots = knots2)

## Predict on the extended grid; column suffixes encode basis and penalty order
p2_bs_2 <- as_tibble(predict(m2_bs_2, new_data, se.fit = TRUE)) %>%
  rename(fit_bs_2 = fit, se_bs_2 = se.fit)
p2_bs_1 <- as_tibble(predict(m2_bs_1, new_data, se.fit = TRUE)) %>%
  rename(fit_bs_1 = fit, se_bs_1 = se.fit)
p2_bs_0 <- as_tibble(predict(m2_bs_0, new_data, se.fit = TRUE)) %>%
  rename(fit_bs_0 = fit, se_bs_0 = se.fit)

## Reshape to one row per x and penalty order, with `fit`, `se` and an
## approximate interval of fit +/- 2 standard errors
new_data_order <- bind_cols(new_data, p2_bs_2, p2_bs_1, p2_bs_0) %>%
  pivot_longer(fit_bs_2:se_bs_0, names_sep = "_",
               names_to = c("variable", "spline", "order")) %>%
  pivot_wider(names_from = variable, values_from = value) %>%
  mutate(upr_ci = fit + (2 * se), lwr_ci = fit - (2 * se))

## Lines use a temporary `colour2` aesthetic, renamed to colour by
## rename_geom_aes(), so points and lines get separate colour scales
ggplot(mapping = aes(x = x, y = y)) +
  geom_ribbon(data = new_data_order,
              mapping = aes(ymin = lwr_ci, ymax = upr_ci, x = x, fill = order),
              inherit.aes = FALSE, alpha = 0.1) +
  geom_point(data = r_samp, aes(colour = data_set)) +
  geom_line(data = new_data_order, aes(y = fit, x = x, colour2 = order),
            size = 1) %>%
    rename_geom_aes(new_aes = c("colour" = "colour2")) +
  scale_colour_brewer(palette = "Set1", aesthetics = "colour",
                      name = "Data set") +
  scale_colour_OkabeIto(aesthetics = "colour2", name = "Penalty") +
  scale_fill_OkabeIto(name = "Penalty") +
  coord_cartesian(ylim = c(-2, 2)) +
  labs(title = "Extrapolating with B splines",
       subtitle = "How behaviour varies with penalties of different order")
## Combine multiple penalties on a single cubic B-spline: elements of m after
## the first each request a derivative penalty of that order

## Training subset (defined here so this section runs on its own)
train <- filter(r_samp, data_set == "train")

m2_bs_21 <- gam(y ~ s(x, k = 50, bs = "bs", m = c(3, 2, 1)),
                data = train, method = "REML", knots = knots2)
m2_bs_210 <- gam(y ~ s(x, k = 50, bs = "bs", m = c(3, 2, 1, 0)),
                 data = train, method = "REML", knots = knots2)

## Predict on the extended grid; the suffix lists the penalty orders used
p2_bs_21 <- as_tibble(predict(m2_bs_21, new_data, se.fit = TRUE)) %>%
  rename(fit_bs_21 = fit, se_bs_21 = se.fit)
p2_bs_210 <- as_tibble(predict(m2_bs_210, new_data, se.fit = TRUE)) %>%
  rename(fit_bs_210 = fit, se_bs_210 = se.fit)

## Reuse the single second-order-penalty predictions (p2_bs_2) as the baseline.
## Reshape to one row per x and penalty set and add a readable `penalty` label.
new_data_multi <- bind_cols(new_data, p2_bs_2, p2_bs_21, p2_bs_210) %>%
  pivot_longer(fit_bs_2:se_bs_210, names_sep = "_",
               names_to = c("variable", "spline", "order")) %>%
  pivot_wider(names_from = variable, values_from = value) %>%
  mutate(upr_ci = fit + (2 * se), lwr_ci = fit - (2 * se),
         penalty = case_when(order == "2" ~ "2",
                             order == "21" ~ "2, 1",
                             order == "210" ~ "2, 1, 0"))

## Lines use a temporary `colour2` aesthetic, renamed to colour by
## rename_geom_aes(), so points and lines get separate colour scales
ggplot(mapping = aes(x = x, y = y)) +
  geom_ribbon(data = new_data_multi,
              mapping = aes(ymin = lwr_ci, ymax = upr_ci, x = x,
                            fill = penalty),
              inherit.aes = FALSE, alpha = 0.1) +
  geom_point(data = r_samp, aes(colour = data_set)) +
  geom_line(data = new_data_multi, aes(y = fit, x = x, colour2 = penalty),
            size = 1) %>%
    rename_geom_aes(new_aes = c("colour" = "colour2")) +
  scale_colour_brewer(palette = "Set1", aesthetics = "colour",
                      name = "Data set") +
  scale_colour_OkabeIto(aesthetics = "colour2", name = "Penalty") +
  scale_fill_OkabeIto(name = "Penalty") +
  coord_cartesian(ylim = c(-2, 2)) +
  labs(title = "Extrapolating with B splines",
       subtitle = "How behaviour changes when combining multiple penalties")
## Sign up for free
## to join this conversation on GitHub.
## Already have an account?
## Sign in to comment