Skip to content

Instantly share code, notes, and snippets.

@gavinsimpson
Last active October 16, 2023 21:21
Show Gist options
  • Save gavinsimpson/99741b2bfc7a0bbf300e55cd7110eee3 to your computer and use it in GitHub Desktop.
Save gavinsimpson/99741b2bfc7a0bbf300e55cd7110eee3 to your computer and use it in GitHub Desktop.
Extrapolating with penalised B-splines and mgcv
## Can we do better extrapolation than the TPRS example used by Gabriel
## Riutort-Mayol et al in their recent paper on low rank approximations to
## Gaussian Processes? https://arxiv.org/abs/2004.11408
## Packages
library('ggplot2')
library('tibble')
library('tidyr')
library('dplyr')
library('mgcv')
## remotes::install_github("clauswilke/colorblindr")
library('colorblindr')
## remotes::install_github("clauswilke/relayer")
library('relayer')
## Univariate example -----------------------------------------------------------
## Fetch the realisations of the true function. Loading this .rData file is
## expected to create the object `f_true` that is used below.
load(url("https://github.com/gabriuma/basis_functions_approach_to_GP/raw/master/Paper/Case-study_1D-Simulated-data/r-code/f_true.rData"))

## One seed, re-set before each random step (noise simulation and subsampling)
seed <- 1234

## Reference data: the true function on a regular grid over [-1, 1] with
## Gaussian noise (sd = 0.2) added to give the observed response `y`
set.seed(seed)
gp_data <- tibble(
  truth = unname(f_true),
  x = seq(-1, 1, by = 0.002)
) %>%
  mutate(y = truth + rnorm(length(truth), mean = 0, sd = 0.2))

## Draw 250 rows for further analysis, order them by `x`, and flag as "test"
## the two tails (|x| beyond 0.8) plus three interior gaps; every other row
## is "train"
set.seed(seed)
r_samp <- gp_data %>%
  sample_n(size = 250) %>%
  arrange(x) %>%
  mutate(
    data_set = case_when(
      x < -0.8 ~ "test",
      x > 0.8 ~ "test",
      x > -0.45 & x < -0.36 ~ "test",
      x > -0.05 & x < 0.05 ~ "test",
      x > 0.45 & x < 0.6 ~ "test",
      TRUE ~ "train"
    )
  )
## Plot the random sample, indicating the samples we will extrapolate and
## interpolate values for ("test") alongside those used for fitting ("train")
ggplot(r_samp, aes(x = x, y = y)) +
  ## true function as a faint, uncoloured reference line with no legend entry
  geom_line(aes(y = truth), alpha = 0.5, show.legend = FALSE) +
  ## observations, coloured by whether they are training or test points
  geom_point(aes(colour = data_set)) +
  scale_colour_brewer(palette = "Set1")
## Set the knots for the B-spline basis.
## NOTE(review): per mgcv's documentation for bs = "bs", when four knots are
## supplied the outer pair gives the interval over which the penalty is
## evaluated and the inner pair the interval in which most knots are placed --
## confirm against ?smooth.construct.bs.smooth.spec
knots2 <- list(x = c(-2, -0.9, 0.9, 2))

## Training rows only; the same subset is used for every fit below
train_samp <- filter(r_samp, data_set == "train")

## Fit a few different spline bases, all with k = 50 and REML smoothness
## selection
## B-spline basis, using the knots set above
m2_bs <- gam(
  y ~ s(x, k = 50, bs = "bs", m = c(3, 2)),
  data = train_samp,
  method = "REML",
  knots = knots2
)

## Thin plate regression spline basis
m2_tprs <- gam(
  y ~ s(x, k = 50, bs = "tp"),
  data = train_samp,
  method = "REML"
)

## Gaussian process basis
## NOTE(review): m = c(3, 0.15) presumably selects the correlation function
## and its range parameter -- confirm against ?gp.smooth
m2_gp <- gam(
  y ~ s(x, k = 50, bs = "gp", m = c(3, 0.15)),
  data = train_samp,
  method = "REML"
)
## Data to predict at: a fine grid over [-1.5, 1.5], i.e. wider than the
## [-1, 1] range the data were simulated on
new_data <- tibble(x = seq(-1.5, 1.5, by = 0.002))

## Predict from the three models; the fit and its standard error are renamed
## with a basis suffix so the columns can be split on "_" below
p2_bs <- predict(m2_bs, newdata = new_data, se.fit = TRUE) %>%
  as_tibble() %>%
  rename(fit_bs = fit, se_bs = se.fit)

p2_tprs <- predict(m2_tprs, newdata = new_data, se.fit = TRUE) %>%
  as_tibble() %>%
  rename(fit_tprs = fit, se_tprs = se.fit)

p2_gp <- predict(m2_gp, newdata = new_data, se.fit = TRUE) %>%
  as_tibble() %>%
  rename(fit_gp = fit, se_gp = se.fit)

## Put all predictions together: go long so each column name splits into
## (variable, spline), then widen again so `fit` and `se` are columns, and
## finally add an interval of +/- 2 standard errors around the fit
new_data_bases <- bind_cols(new_data, p2_tprs, p2_bs, p2_gp) %>%
  pivot_longer(
    cols = fit_tprs:se_gp,
    names_to = c("variable", "spline"),
    names_sep = "_"
  ) %>%
  pivot_wider(names_from = variable, values_from = value) %>%
  mutate(
    upr_ci = fit + (2 * se),
    lwr_ci = fit - (2 * se)
  )
## Fitted curves and +/- 2 SE bands for each basis, drawn over the sample
ggplot(mapping = aes(x = x, y = y)) +
  geom_ribbon(
    data = new_data_bases,
    mapping = aes(x = x, ymin = lwr_ci, ymax = upr_ci, fill = spline),
    alpha = 0.1,
    inherit.aes = FALSE
  ) +
  geom_point(data = r_samp, mapping = aes(colour = data_set)) +
  ## The fitted lines are mapped to a stand-in `colour2` aesthetic, which is
  ## then renamed so this layer can carry its own colour scale separate from
  ## the points. Only the geom_line() layer is passed to rename_geom_aes().
  ## NOTE(review): newer ggplot2 releases prefer `linewidth` over `size` for
  ## lines; `size` is kept here to leave behaviour unchanged.
  rename_geom_aes(
    geom_line(
      data = new_data_bases,
      mapping = aes(x = x, y = fit, colour2 = spline),
      size = 1
    ),
    new_aes = c("colour" = "colour2")
  ) +
  scale_colour_brewer(
    palette = "Set1",
    aesthetics = "colour",
    name = "Data set"
  ) +
  scale_colour_OkabeIto(aesthetics = "colour2", name = "Basis") +
  scale_fill_OkabeIto(name = "Basis") +
  coord_cartesian(ylim = c(-2, 2)) +
  labs(
    title = "Extrapolating with splines",
    subtitle = "How behaviour varies with different basis types"
  )
## Compare penalties: the same B-spline smooth (k = 50, same knots, REML),
## differing only in the second element of `m` (2, 1 or 0)
m2_bs_2 <- gam(
  y ~ s(x, k = 50, bs = "bs", m = c(3, 2)),
  data = r_samp %>% filter(data_set == "train"),
  method = "REML",
  knots = knots2
)

m2_bs_1 <- gam(
  y ~ s(x, k = 50, bs = "bs", m = c(3, 1)),
  data = r_samp %>% filter(data_set == "train"),
  method = "REML",
  knots = knots2
)

m2_bs_0 <- gam(
  y ~ s(x, k = 50, bs = "bs", m = c(3, 0)),
  data = r_samp %>% filter(data_set == "train"),
  method = "REML",
  knots = knots2
)

## Predictions on the new_data grid; column suffixes encode basis and
## penalty order so they can be split on "_" below
p2_bs_2 <- predict(m2_bs_2, newdata = new_data, se.fit = TRUE) %>%
  as_tibble() %>%
  rename(fit_bs_2 = fit, se_bs_2 = se.fit)

p2_bs_1 <- predict(m2_bs_1, newdata = new_data, se.fit = TRUE) %>%
  as_tibble() %>%
  rename(fit_bs_1 = fit, se_bs_1 = se.fit)

p2_bs_0 <- predict(m2_bs_0, newdata = new_data, se.fit = TRUE) %>%
  as_tibble() %>%
  rename(fit_bs_0 = fit, se_bs_0 = se.fit)

## Combine: split column names into (variable, spline, order), widen so that
## `fit` and `se` are columns, then add an interval of +/- 2 standard errors
new_data_order <- bind_cols(new_data, p2_bs_2, p2_bs_1, p2_bs_0) %>%
  pivot_longer(
    cols = fit_bs_2:se_bs_0,
    names_to = c("variable", "spline", "order"),
    names_sep = "_"
  ) %>%
  pivot_wider(names_from = variable, values_from = value) %>%
  mutate(
    upr_ci = fit + (2 * se),
    lwr_ci = fit - (2 * se)
  )
## Fitted curves and +/- 2 SE bands for each penalty order, over the sample
ggplot(mapping = aes(x = x, y = y)) +
  geom_ribbon(
    data = new_data_order,
    mapping = aes(x = x, ymin = lwr_ci, ymax = upr_ci, fill = order),
    alpha = 0.1,
    inherit.aes = FALSE
  ) +
  geom_point(data = r_samp, mapping = aes(colour = data_set)) +
  ## The fitted lines use a stand-in `colour2` aesthetic that is renamed so
  ## this layer gets a colour scale separate from the points. Only the
  ## geom_line() layer is passed to rename_geom_aes().
  ## NOTE(review): newer ggplot2 releases prefer `linewidth` over `size` for
  ## lines; `size` is kept here to leave behaviour unchanged.
  rename_geom_aes(
    geom_line(
      data = new_data_order,
      mapping = aes(x = x, y = fit, colour2 = order),
      size = 1
    ),
    new_aes = c("colour" = "colour2")
  ) +
  scale_colour_brewer(
    palette = "Set1",
    aesthetics = "colour",
    name = "Data set"
  ) +
  scale_colour_OkabeIto(aesthetics = "colour2", name = "Penalty") +
  scale_fill_OkabeIto(name = "Penalty") +
  coord_cartesian(ylim = c(-2, 2)) +
  labs(
    title = "Extrapolating with B splines",
    subtitle = "How behaviour varies with penalties of different order"
  )
## B-spline smooths with more than one penalty: `m` lists the extra penalty
## orders after the first two elements (2 and 1; then 2, 1 and 0)
m2_bs_21 <- gam(
  y ~ s(x, k = 50, bs = "bs", m = c(3, 2, 1)),
  data = r_samp %>% filter(data_set == "train"),
  method = "REML",
  knots = knots2
)

m2_bs_210 <- gam(
  y ~ s(x, k = 50, bs = "bs", m = c(3, 2, 1, 0)),
  data = r_samp %>% filter(data_set == "train"),
  method = "REML",
  knots = knots2
)

## Predictions on the new_data grid, suffixed by basis and penalty set
p2_bs_21 <- predict(m2_bs_21, newdata = new_data, se.fit = TRUE) %>%
  as_tibble() %>%
  rename(fit_bs_21 = fit, se_bs_21 = se.fit)

p2_bs_210 <- predict(m2_bs_210, newdata = new_data, se.fit = TRUE) %>%
  as_tibble() %>%
  rename(fit_bs_210 = fit, se_bs_210 = se.fit)

## Readable legend labels keyed by the `order` suffix taken from the column
## names; any other suffix maps to NA
penalty_labels <- c("2" = "2", "21" = "2, 1", "210" = "2, 1, 0")

## Combine with the single second-order-penalty predictions (p2_bs_2), reshape
## so `fit` and `se` are columns, add +/- 2 SE limits and the legend label
new_data_multi <- bind_cols(new_data, p2_bs_2, p2_bs_21, p2_bs_210) %>%
  pivot_longer(
    cols = fit_bs_2:se_bs_210,
    names_to = c("variable", "spline", "order"),
    names_sep = "_"
  ) %>%
  pivot_wider(names_from = variable, values_from = value) %>%
  mutate(
    upr_ci = fit + (2 * se),
    lwr_ci = fit - (2 * se),
    penalty = unname(penalty_labels[order])
  )
## Fitted curves and +/- 2 SE bands for each penalty combination, drawn over
## the sample
ggplot(mapping = aes(x = x, y = y)) +
  geom_ribbon(
    data = new_data_multi,
    mapping = aes(x = x, ymin = lwr_ci, ymax = upr_ci, fill = penalty),
    alpha = 0.1,
    inherit.aes = FALSE
  ) +
  geom_point(data = r_samp, mapping = aes(colour = data_set)) +
  ## The fitted lines use a stand-in `colour2` aesthetic that is renamed so
  ## this layer gets a colour scale separate from the points. Only the
  ## geom_line() layer is passed to rename_geom_aes().
  ## NOTE(review): newer ggplot2 releases prefer `linewidth` over `size` for
  ## lines; `size` is kept here to leave behaviour unchanged.
  rename_geom_aes(
    geom_line(
      data = new_data_multi,
      mapping = aes(x = x, y = fit, colour2 = penalty),
      size = 1
    ),
    new_aes = c("colour" = "colour2")
  ) +
  scale_colour_brewer(
    palette = "Set1",
    aesthetics = "colour",
    name = "Data set"
  ) +
  scale_colour_OkabeIto(aesthetics = "colour2", name = "Penalty") +
  scale_fill_OkabeIto(name = "Penalty") +
  coord_cartesian(ylim = c(-2, 2)) +
  labs(
    title = "Extrapolating with B splines",
    subtitle = "How behaviour changes when combining multiple penalties"
  )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment