Created
May 7, 2020 19:18
-
-
Save gavinsimpson/727900bcd634fa530d2bd7316aa9d065 to your computer and use it in GitHub Desktop.
Animated spline basis functions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Plots and animations showing how basis functions are combined to make
## splines, and how those splines are fitted to data
library('ggplot2') | |
library('tibble') | |
library('tidyr') | |
library('dplyr') | |
library('mgcv') | |
library('mvnfast') | |
library('purrr') | |
library('gganimate') | |
theme_set(theme_minimal()) | |
## Test function used to generate the "true" signal: the sum of two
## polynomial terms in `x`. Vectorised over `x`.
##
## Arguments:
##   x - numeric vector of evaluation points.
##
## Returns a numeric vector the same length as `x`.
f <- function(x) {
  first_term <- x^11 * (10 * (1 - x))^6
  second_term <- (10 * (10 * x)^3) * (1 - x)^10
  first_term + second_term
}
## Draw `n` random coefficient vectors of length `k` from a multivariate
## normal with mean `rep(mu, k)` and a diagonal covariance matrix.
##
## Arguments:
##   n     - number of draws requested from rmvn().
##   k     - number of coefficients per draw.
##   mu    - scalar mean, repeated for every coefficient.
##   sigma - value(s) placed on the diagonal of the matrix passed to rmvn()
##           as `sigma`; a scalar, or a vector of length `k`.
##
## Returns whatever rmvn() returns for these arguments (presumably an
## n x k matrix -- confirm against the mvnfast documentation).
draw_beta <- function(n, k, mu = 1, sigma = 1) {
  ## diag(sigma, nrow = k) always yields a k x k matrix. The one-argument
  ## form diag(rep(sigma, k)) degenerates to diag(sigma) when k == 1, which
  ## R treats as "identity matrix of size sigma" rather than a 1 x 1 matrix.
  rmvn(n = n, mu = rep(mu, k), sigma = diag(sigma, nrow = k))
}
## Scale each column of the basis matrix `bf` by a randomly drawn
## coefficient and return the weighted basis in long format.
##
## Arguments:
##   bf  - basis matrix; its columns are swept by the drawn coefficients.
##   x   - covariate values, added as a column alongside the rows of `bf`.
##   n   - number of coefficient draws requested from draw_beta().
##   k   - number of coefficients to draw.
##   ... - forwarded to draw_beta().
##
## Returns a tibble with columns `x`, `bf` (labels 'f1', 'f2', ...) and `y`
## (the weighted basis function values).
weight_basis <- function(bf, x, n = 1, k, ...) {
  coefs <- draw_beta(n = n, k = k, ...)
  weighted <- sweep(bf, 2L, coefs, '*')
  colnames(weighted) <- paste0('f', seq_along(coefs))
  weighted %>%
    as_tibble() %>%
    add_column(x = x) %>%
    pivot_longer(-x, names_to = 'bf', values_to = 'y')
}
## Generate `draws` independently weighted copies of the basis `bf`.
##
## Arguments:
##   bf    - basis matrix, passed to weight_basis().
##   x     - covariate values, passed to weight_basis().
##   draws - number of random weightings to generate.
##   k     - number of basis functions; also used to size the `draw` index,
##           so each weighting is expected to contribute length(x) * k rows.
##   ...   - forwarded to weight_basis().
##
## Returns the row-bound weight_basis() results with a leading integer
## `draw` column, classed "random_bases" so plot() dispatches to
## plot.random_bases().
random_bases <- function(bf, x, draws = 10, k, ...) {
  ## purrr::rerun() is deprecated; lapply() over the draw index evaluates
  ## weight_basis() the same number of times, in the same order.
  out <- lapply(seq_len(draws),
                function(i) weight_basis(bf, x = x, k = k, ...))
  out <- bind_rows(out)
  out <- add_column(out, draw = rep(seq_len(draws), each = length(x) * k),
                    .before = 1L)
  class(out) <- c("random_bases", class(out))
  out
}
## Plot method for "random_bases" objects: one line per basis function,
## coloured by `bf`, with the colour legend suppressed.
##
## Arguments:
##   x     - a "random_bases" object with columns `x`, `y`, `bf` and `draw`.
##   facet - if TRUE, draw each value of `draw` in its own panel.
##   ...   - ignored; present for compatibility with the plot() generic.
##
## Returns the ggplot object.
plot.random_bases <- function(x, facet = FALSE, ...) {
  plt <- ggplot(x, aes(x = x, y = y, colour = bf)) +
    geom_line(lwd = 1) +
    guides(colour = "none")
  if (facet) {
    ## the faceted plot must be assigned back, otherwise it is discarded
    ## and the unfaceted plot is returned
    plt <- plt + facet_wrap(~ draw)
  }
  plt
}
## Linearly rescale `x` so that its minimum maps to 0 and its maximum to 1.
##
## Arguments:
##   x - numeric vector.
##
## Returns a numeric vector the same length as `x`. Missing values are not
## removed before taking the range, and a constant `x` gives 0/0.
normalize <- function(x) {
  limits <- range(x)
  lower <- limits[1]
  span <- limits[2] - lower
  (x - lower) / span
}
## Simulate N noisy observations of f() at uniformly drawn covariate values
set.seed(1)
N <- 500
data <- tibble(x = runif(N),
               ytrue = f(x),
               ycent = ytrue - mean(ytrue),
               yobs = ycent + rnorm(N, sd = 0.5),
               ynorm = normalize(yobs))

## Cubic regression spline basis with k knots evenly spaced over range(x)
k <- 10
knots <- with(data, list(x = seq(min(x), max(x), length.out = k)))
sm <- smoothCon(s(x, k = k, bs = "cr"), data = data, knots = knots)[[1]]$X
colnames(sm) <- levs <- paste0("f", seq_len(k))

## Long format: one row per observation and basis function. Every column of
## `data` (x through ynorm) is kept as an identifier so that only the basis
## columns are pivoted; excluding just x:yobs would pivot `ynorm` as if it
## were an extra basis function.
basis <- pivot_longer(cbind(sm, data), -(x:ynorm), names_to = 'bf')
basis

## The unweighted basis functions
ggplot(basis, aes(x = x, y = value, colour = bf)) +
  geom_line(lwd = 1, alpha = 0.4) +
  guides(colour = "none")
## Twenty random weightings of the cubic regression spline basis
set.seed(2)
bfuns <- random_bases(sm, data$x, draws = 20, k = k)

## Summing the weighted basis functions at each x, within each draw, gives
## the spline implied by that draw's coefficients
smooth <- ungroup(summarise(group_by(bfuns, draw, x), spline = sum(y)))

## The splines on their own
p1 <- ggplot(smooth) +
  geom_line(data = smooth, mapping = aes(x = x, y = spline), lwd = 1.5) +
  labs(x = 'x', y = 'f(x)') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
p1

## Animate through the draws, one state per draw
smooth_funs <- animate(
  p1 +
    transition_states(draw, transition_length = 4, state_length = 2) +
    ease_aes('cubic-in-out'),
  nframes = 200, width = 800, height = 800/1.77777, res = 120)

## Weighted basis functions with the resulting spline drawn on top
p <- plot(bfuns) +
  geom_line(data = smooth, mapping = aes(x = x, y = spline),
            inherit.aes = FALSE, lwd = 1.5) +
  theme(text = element_text(size = 16)) +
  labs(x = 'x', y = expression(f(x)))

animate(
  p +
    transition_states(draw, transition_length = 4, state_length = 2) +
    ease_aes('cubic-in-out'),
  nframes = 200)
## Reference plot: centred true function (line) over the noisy observations
data_plt <- ggplot(data, aes(x = x, y = ycent)) +
  geom_line(col = 'goldenrod', lwd = 2) +
  geom_point(aes(y = yobs), alpha = 0.2) +
  theme(text = element_text(size = 16))
data_plt

## Least squares fit of the cubic regression spline basis (no intercept)
sm2 <- smoothCon(s(x, k = k, bs = "cr"), data = data, knots = knots)[[1]]$X
## NOTE(review): the response here is the noise-free `ycent`, whereas the
## thin plate and Gaussian process fits later in the file use `yobs` --
## confirm this is intended
beta <- coef(lm(ycent ~ sm2 - 1, data = data))
wtbasis <- sweep(sm2, 2L, beta, FUN = "*")
colnames(wtbasis) <- colnames(sm2) <- paste0("F", seq_len(k))

## create stacked unweighted and weighted basis
basis <- as_tibble(rbind(sm2, wtbasis)) %>%
  add_column(x = rep(data$x, times = 2),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm2)),
             .before = 1L)

## same stack plus `fitted`, the row sum of the basis columns
wtbasis <- as_tibble(rbind(sm2, wtbasis)) %>%
  add_column(x = rep(data$x, times = 2),
             fitted = rowSums(.),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm2))) %>%
  pivot_longer(-(x:type), names_to = 'bf')
basis <- pivot_longer(basis, -(x:type), names_to = 'bf')

## Static plots. `basis` and `wtbasis` hold the unweighted and the weighted
## rows stacked on top of each other, so each static plot is faceted by
## `type`; otherwise both sets of rows fall in the same line group and the
## lines zig-zag between them at every x.
ggplot(wtbasis, aes(x = x, y = value, colour = bf)) +
  geom_line(lwd = 1, alpha = 0.4) +
  facet_wrap(~ type, scales = 'free_y') +
  guides(colour = "none") +
  theme(text = element_text(size = 16)) +
  labs(x = 'x', y = expression(f(x)))

ggplot(basis, aes(x = x, y = value, colour = bf)) +
  geom_line(lwd = 1, alpha = 0.4) +
  facet_wrap(~ type, scales = 'free_y') +
  guides(colour = "none") +
  theme(text = element_text(size = 16)) +
  labs(x = 'x', y = expression(f(x)))

## basis functions over the data
data_plt +
  geom_line(data = wtbasis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  facet_wrap(~ type) +
  guides(colour = "none") +
  theme(text = element_text(size = 16))

## basis functions and their sum over the data
data_plt +
  geom_line(data = wtbasis,
            mapping = aes(x = x, y = fitted),
            lwd = 1.5, colour = 'steelblue2', alpha = 0.75) +
  geom_line(data = wtbasis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.7) +
  facet_wrap(~ type) +
  guides(colour = "none") +
  labs(y = expression(f(x)), x = 'x') +
  theme(text = element_text(size = 16))

## Animation frame: not faceted, because transition_states() separates the
## unweighted and weighted rows into distinct states
p3 <- ggplot(data, aes(x = x, y = ycent)) +
  geom_point(aes(y = yobs), alpha = 0.2) +
  geom_line(data = basis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  geom_line(data = wtbasis,
            mapping = aes(x = x, y = fitted),
            lwd = 1, colour = 'black', alpha = 0.75) +
  guides(colour = "none") +
  labs(y = 'f(x)', x = 'x', title = 'GAMs: learning from data',
       subtitle = 'Cubic regression spline',
       caption = '@ucfagls') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
p3

animate(p3 + transition_states(type, transition_length = 4, state_length = 2) +
          ease_aes('cubic-in-out'),
        nframes = 100, height = 700, width = 800, res = 120)
## Penalised thin plate regression spline fitted by augmented least squares
sm_tprs <- smoothCon(s(x, k = k, bs = "tp"), absorb.cons = TRUE,
                     data = data)[[1]]
E <- t(mroot(sm_tprs$S[[1]]))          # square root penalty

## Augmented model matrix. The intercept is an explicit first column that is
## 1 on the data rows and 0 on the penalty pseudo-rows; lm()'s implicit
## intercept would be 1 on the pseudo-rows as well and so leak into the
## penalty term.
sm_tprsX <- rbind(cbind(1, sm_tprs$X), cbind(0, 0.1 * E))
y <- c(data$yobs, rep(0, nrow(E)))     # augmented data

## y and sm_tprsX are taken from the workspace (they have more rows than
## `data`), so no `data` argument is supplied
beta <- coef(lm(y ~ sm_tprsX - 1))     # beta[1] is the intercept
spline <- sweep(sm_tprs$X, 2L, beta[-1], FUN = "*")
sm_tprs <- sm_tprs$X
colnames(spline) <- colnames(sm_tprs) <- paste0("F", seq_len(k-1))

## create stacked unweighted and weighted basis
basis <- as_tibble(rbind(sm_tprs, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_tprs)),
             .before = 1L)

## same stack plus `fitted`: row sum of the basis columns plus the intercept
spline <- as_tibble(rbind(sm_tprs, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             fitted = rowSums(.) + beta[1L],
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_tprs))) %>%
  pivot_longer(-(x:type), names_to = 'bf')
basis <- pivot_longer(basis, -(x:type), names_to = 'bf')

ptprs <- ggplot(data, aes(x = x, y = yobs)) +
  geom_point(alpha = 0.2) +
  geom_line(data = basis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  geom_line(data = spline,
            mapping = aes(x = x, y = fitted),
            lwd = 1, colour = 'black', alpha = 0.75) +
  guides(colour = "none") +
  labs(y = 'f(x)', x = 'x', title = 'GAMs: learning from data',
       subtitle = 'Penalised thin plate regression spline',
       caption = '@ucfagls') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
ptprs

animate(ptprs + transition_states(type, transition_length = 4, state_length = 2) +
          ease_aes('cubic-in-out'),
        nframes = 100, height = 700, width = 800, res = 120)
## Gaussian process smooth basis, fitted by ordinary (unpenalised) least
## squares on the basis columns with no intercept
sm_gp <- smoothCon(s(x, bs = 'gp', k = k, m = c(3, 0.25)), data = data)[[1]]$X
beta <- coef(lm(yobs ~ sm_gp - 1, data = data))
spline <- sweep(sm_gp, 2L, beta, FUN = "*")
colnames(spline) <- colnames(sm_gp) <- paste0("F", seq_len(k))

## create stacked unweighted and weighted basis
basis <- as_tibble(rbind(sm_gp, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_gp)),
             .before = 1L)

## same stack plus `fitted`, the row sum of the basis columns
spline <- as_tibble(rbind(sm_gp, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             fitted = rowSums(.),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_gp))) %>%
  pivot_longer(-(x:type), names_to = 'bf')
basis <- pivot_longer(basis, -(x:type), names_to = 'bf')

pgp <- ggplot(data, aes(x = x, y = yobs)) +
  geom_point(alpha = 0.2) +
  geom_line(data = basis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  geom_line(data = spline,
            mapping = aes(x = x, y = fitted),
            lwd = 1, colour = 'black', alpha = 0.75) +
  guides(colour = "none") +
  labs(y = 'f(x)', x = 'x', title = 'GAMs: learning from data',
       subtitle = 'Gaussian process — Matérn(κ=1.5; ρ=0.25)',
       caption = '@ucfagls') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
pgp

animate(pgp + transition_states(type, transition_length = 4, state_length = 2) +
          ease_aes('cubic-in-out'),
        nframes = 100, height = 700, width = 800, res = 120)

## Static view of the unweighted and weighted basis functions, one panel each
ggplot(basis, aes(x = x, y = value, colour = bf)) +
  geom_line() +
  facet_wrap(~ type, scales = 'free_y')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment