Skip to content

Instantly share code, notes, and snippets.

@gavinsimpson
Created May 7, 2020 19:18
Show Gist options
  • Save gavinsimpson/727900bcd634fa530d2bd7316aa9d065 to your computer and use it in GitHub Desktop.
Save gavinsimpson/727900bcd634fa530d2bd7316aa9d065 to your computer and use it in GitHub Desktop.
Animated spline basis functions
## plots and animations of how basis functions are used to make
## splines and how these are fitted to data
library('ggplot2')
library('tibble')
library('tidyr')
library('dplyr')
library('mgcv')
library('mvnfast')
library('purrr')
library('gganimate')
## make theme_minimal() the default ggplot2 theme for the plots drawn below
theme_set(theme_minimal())
## Test function used as the "true" signal when simulating the data.
##
## The value is the sum of two polynomial terms in `x`; the computation is
## vectorised, so `x` may be a numeric vector of any length.
##
## @param x numeric vector of evaluation points.
## @return numeric vector the same length as `x`.
f <- function(x) {
  first_term <- x^11 * (10 * (1 - x))^6
  second_term <- ((10 * (10 * x)^3) * (1 - x)^10)
  first_term + second_term
}
## Draw random basis-function weights from a multivariate normal via rmvn().
##
## @param n number of draws requested from rmvn().
## @param k number of coefficients per draw (dimension of the distribution).
## @param mu mean of the coefficients; repeated `k` times.
## @param sigma value(s) placed on the diagonal of the covariance matrix
##   passed to rmvn(); all off-diagonal entries are zero.
## @return the value of rmvn() for these arguments.
draw_beta <- function(n, k, mu = 1, sigma = 1) {
  ## diag() called with a single length-one vector builds an identity matrix
  ## of that *size*, which is what diag(rep(sigma, k)) collapses to when
  ## k = 1; giving the dimensions explicitly keeps sigma on the diagonal
  covariance <- diag(sigma, nrow = k, ncol = k)
  rmvn(n = n, mu = rep(mu, k), sigma = covariance)
}
## Weight every basis function by a randomly drawn coefficient.
##
## @param bf matrix of basis functions, one column per basis function and one
##   row per element of `x`.
## @param x covariate values stored alongside the weighted basis.
## @param n,k,... passed on to draw_beta().
## @return long-format tibble with columns `x`, `bf` (basis function label)
##   and `y` (weighted basis function value).
weight_basis <- function(bf, x, n = 1, k, ...) {
  coefs <- draw_beta(n = n, k = k, ...)
  ## multiply each column of the basis by its coefficient
  weighted <- sweep(bf, 2L, coefs, '*')
  colnames(weighted) <- paste0('f', seq_along(coefs))
  weighted %>%
    as_tibble() %>%
    add_column(x = x) %>%
    pivot_longer(-x, names_to = 'bf', values_to = 'y')
}
## Generate several randomly weighted copies of a basis.
##
## @param bf matrix of basis functions, one column per basis function.
## @param x covariate values, one per row of `bf`.
## @param draws number of random sets of weights to generate.
## @param k number of basis functions (columns of `bf`).
## @param ... passed on to weight_basis().
## @return the row-bound output of weight_basis() for every draw, with an
##   integer `draw` column in front and class "random_bases" prepended.
random_bases <- function(bf, x, draws = 10, k, ...) {
  ## purrr::rerun() is deprecated as of purrr 1.0.0; lapply() evaluates the
  ## draws one after another in the same order, so the random number stream
  ## is consumed exactly as before
  sims <- lapply(seq_len(draws),
                 function(i) weight_basis(bf, x = x, k = k, ...))
  out <- bind_rows(sims)
  ## each draw contributes length(x) * k rows (k basis functions at each x)
  out <- add_column(out, draw = rep(seq_len(draws), each = length(x) * k),
                    .before = 1L)
  class(out) <- c("random_bases", class(out))
  out
}
## plot() method for "random_bases" objects: one line per basis function.
##
## @param x object created by random_bases().
## @param facet logical; if TRUE each draw is shown in its own panel.
## @param ... ignored; present so the signature is compatible with the
##   plot() generic.
## @return a ggplot object.
plot.random_bases <- function(x, facet = FALSE, ...) {
  plt <- ggplot(x, aes(x = x, y = y, colour = bf)) +
    geom_line(lwd = 1) +
    guides(colour = "none") # FALSE is deprecated here; "none" hides the legend
  if (facet) {
    ## the faceted plot has to be kept, otherwise `facet` has no effect
    plt <- plt + facet_wrap(~ draw)
  }
  plt
}
## Rescale a numeric vector linearly so its minimum maps to 0 and its
## maximum maps to 1.
##
## @param x numeric vector.
## @return numeric vector the same length as `x`.
normalize <- function(x) {
  bounds <- range(x)
  lower <- bounds[1]
  span <- bounds[2] - lower
  (x - lower) / span
}
## Simulate N observations of f() at uniformly distributed x
##   ytrue - noise-free value of f()
##   ycent - ytrue centred on its mean
##   yobs  - ycent plus Gaussian noise (sd = 0.5)
##   ynorm - yobs rescaled with normalize()
set.seed(1)
N <- 500
data <- tibble(x = runif(N)) %>%
  mutate(ytrue = f(x),
         ycent = ytrue - mean(ytrue),
         yobs = ycent + rnorm(N, sd = 0.5),
         ynorm = normalize(yobs))
## Cubic regression spline (bs = "cr") basis with k basis functions
k <- 10
## k evenly spaced knots spanning the observed range of x
knots <- with(data, list(x = seq(min(x), max(x), length.out = k)))
## model matrix of the basis: one column per basis function
sm <- smoothCon(s(x, k = k, bs = "cr"), data = data, knots = knots)[[1]]$X
colnames(sm) <- levs <- paste0("f", seq_len(k))
## long format for plotting; every column of `data` (x through ynorm) is kept
## as an id column so that only the basis functions end up in `value`
basis <- pivot_longer(cbind(sm, data), -(x:ynorm), names_to = 'bf')
basis
## the unweighted basis functions
ggplot(basis, aes(x = x, y = value, colour = bf)) +
  geom_line(lwd = 1, alpha = 0.4) +
  guides(colour = "none")
## 20 random sets of weights for the basis in `sm`
set.seed(2)
bfuns <- random_bases(sm, data$x, draws = 20, k = k)

## the spline for each draw is the sum of its weighted basis functions at
## each value of x
smooth <- ungroup(summarise(group_by(bfuns, draw, x), spline = sum(y)))

## the splines on their own
p1 <- ggplot(smooth) +
  geom_line(mapping = aes(x = x, y = spline), data = smooth, lwd = 1.5) +
  labs(x = 'x', y = 'f(x)') +
  theme_minimal(base_family = 'Titillium', base_size = 14)
p1

## animate through the draws
smooth_funs <- animate(
  p1 +
    transition_states(draw, transition_length = 4, state_length = 2) +
    ease_aes('cubic-in-out'),
  nframes = 200, width = 800, height = 800/1.77777, res = 120
)

## weighted basis functions with the resulting spline drawn on top
p <- plot(bfuns) +
  geom_line(mapping = aes(x = x, y = spline), data = smooth,
            inherit.aes = FALSE, lwd = 1.5) +
  theme(text = element_text(size = 16)) +
  labs(x = 'x', y = expression(f(x)))

animate(
  p +
    transition_states(draw, transition_length = 4, state_length = 2) +
    ease_aes('cubic-in-out'),
  nframes = 200
)
## centred true function (line) and noisy observations (points)
data_plt <- ggplot(data, aes(x = x, y = ycent)) +
  geom_line(col = 'goldenrod', lwd = 2) +
  geom_point(aes(y = yobs), alpha = 0.2) +
  theme(text = element_text(size = 16))
data_plt

## Cubic regression spline: estimate the basis function weights by least
## squares, with no intercept
sm2 <- smoothCon(s(x, k = k, bs = "cr"), data = data, knots = knots)[[1]]$X
## NOTE(review): the response here is the noise-free `ycent`, whereas the
## thin plate and Gaussian process fits later in the script use `yobs` -
## confirm this is intended
beta <- coef(lm(ycent ~ sm2 - 1, data = data))
## each basis function multiplied by its estimated weight
wtbasis <- sweep(sm2, 2L, beta, FUN = "*")
colnames(wtbasis) <- colnames(sm2) <- paste0("F", seq_len(k))

## create stacked unweighted and weighted basis
basis <- as_tibble(rbind(sm2, wtbasis)) %>%
  add_column(x = rep(data$x, times = 2),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm2)),
             .before = 1L)
##data <- cbind(data, fitted = rowSums(scbasis))
## same stack plus `fitted`, the row sum over the basis function columns
## (`.` is the stacked tibble before the new columns are added)
wtbasis <- as_tibble(rbind(sm2, wtbasis)) %>%
  add_column(x = rep(data$x, times = 2),
             fitted = rowSums(.),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm2))) %>%
  pivot_longer(-(x:type), names_to = 'bf')
basis <- pivot_longer(basis, -(x:type), names_to = 'bf')

## guides(colour = "none") replaces the deprecated guides(colour = FALSE)
ggplot(wtbasis, aes(x = x, y = value, colour = bf)) +
  geom_line(lwd = 1, alpha = 0.4) +
  guides(colour = "none") +
  theme(text = element_text(size = 16)) +
  labs(x = 'x', y = expression(f(x)))

ggplot(basis, aes(x = x, y = value, colour = bf)) +
  geom_line(lwd = 1, alpha = 0.4) +
  guides(colour = "none") +
  theme(text = element_text(size = 16)) +
  labs(x = 'x', y = expression(f(x)))

## basis functions over the data
data_plt +
  geom_line(data = wtbasis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  guides(colour = "none") +
  theme(text = element_text(size = 16))

## basis functions and their sum over the data
data_plt +
  geom_line(data = wtbasis,
            mapping = aes(x = x, y = fitted),
            lwd = 1.5, colour = 'steelblue2', alpha = 0.75) +
  geom_line(data = wtbasis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.7) +
  guides(colour = "none") +
  labs(y = expression(f(x)), x = 'x') +
  theme(text = element_text(size = 16))

## animation between the unweighted and the weighted basis
p3 <- ggplot(data, aes(x = x, y = ycent)) +
  geom_point(aes(y = yobs), alpha = 0.2) +
  geom_line(data = basis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  geom_line(data = wtbasis,
            mapping = aes(x = x, y = fitted),
            lwd = 1, colour = 'black', alpha = 0.75) +
  guides(colour = "none") +
  labs(y = 'f(x)', x = 'x', title = 'GAMs: learning from data',
       subtitle = 'Cubic regression spline',
       caption = '@ucfagls') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
p3

animate(p3 + transition_states(type, transition_length = 4, state_length = 2) +
          ease_aes('cubic-in-out'),
        nframes = 100, height = 700, width = 800, res = 120)
## Penalised thin plate regression spline (bs = "tp") fitted by least squares
## on augmented data. `absorb.cons` is spelled out in full rather than relying
## on partial argument matching.
sm_tprs <- smoothCon(s(x, k = k, bs = "tp"), absorb.cons = TRUE,
                     data = data)[[1]]
E <- t(mroot(sm_tprs$S[[1]])) # square root penalty
## augmented model matrix: the basis stacked on the scaled square root
## penalty, preceded by an explicit intercept column that is 1 for the data
## rows and 0 for the penalty rows. Leaving the intercept to lm() would put a
## 1 in the penalty rows as well, so the intercept would enter the penalty.
sm_tprsX <- cbind(rep(c(1, 0), times = c(nrow(sm_tprs$X), nrow(E))),
                  rbind(sm_tprs$X, 0.1 * E))
y <- c(data$yobs, rep(0, nrow(E))) # augmented data
## `y` and `sm_tprsX` are longer than `data`, so no `data` argument is given;
## beta[1] is the intercept, beta[-1] are the basis function weights
beta <- coef(lm(y ~ sm_tprsX - 1))
spline <- sweep(sm_tprs$X, 2L, beta[-1], FUN = "*")
sm_tprs <- sm_tprs$X
colnames(spline) <- colnames(sm_tprs) <- paste0("F", seq_len(k-1))

## create stacked unweighted and weighted basis
basis <- as_tibble(rbind(sm_tprs, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_tprs)),
             .before = 1L)
##data <- cbind(data, fitted = rowSums(scbasis))
## same stack plus `fitted`: row sum over the basis function columns plus the
## intercept
spline <- as_tibble(rbind(sm_tprs, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             fitted = rowSums(.) + beta[1L],
             type = rep(c('unweighted', 'weighted'),
                        each = nrow(sm_tprs))) %>%
  pivot_longer(-(x:type), names_to = 'bf')
basis <- pivot_longer(basis, -(x:type), names_to = 'bf')

## guides(colour = "none") replaces the deprecated guides(colour = FALSE)
ptprs <- ggplot(data, aes(x = x, y = yobs)) +
  geom_point(alpha = 0.2) +
  geom_line(data = basis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  geom_line(data = spline,
            mapping = aes(x = x, y = fitted),
            lwd = 1, colour = 'black', alpha = 0.75) +
  guides(colour = "none") +
  labs(y = 'f(x)', x = 'x', title = 'GAMs: learning from data',
       subtitle = 'Penalised thin plate regression spline',
       caption = '@ucfagls') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
ptprs

animate(ptprs + transition_states(type, transition_length = 4, state_length = 2) +
          ease_aes('cubic-in-out'),
        nframes = 100, height = 700, width = 800, res = 120)
## Gaussian process smooth (bs = 'gp') with m = c(3, 0.25); weights are
## estimated by unpenalised least squares on yobs, with no intercept
sm_gp <- smoothCon(s(x, bs = 'gp', k = k, m = c(3, 0.25)), data = data)[[1]]$X
## E <- t(mroot(sm_gp$S[[1]])) # square root penalty
## sm_gpX <- rbind(sm_gp$X, 0.1 * # augmented model matrix
## y <- c(data$yobs, rep(0, nrow(E))) # augmented data
beta <- coef(lm(yobs ~ sm_gp - 1, data = data))
## each basis function multiplied by its estimated weight
spline <- sweep(sm_gp, 2L, beta, FUN = "*")
colnames(spline) <- colnames(sm_gp) <- paste0("F", seq_len(k))

## create stacked unweighted and weighted basis
basis <- as_tibble(rbind(sm_gp, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_gp)),
             .before = 1L)
##data <- cbind(data, fitted = rowSums(scbasis))
## same stack plus `fitted`, the row sum over the basis function columns
spline <- as_tibble(rbind(sm_gp, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             fitted = rowSums(.),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_gp))) %>%
  pivot_longer(-(x:type), names_to = 'bf')
basis <- pivot_longer(basis, -(x:type), names_to = 'bf')

## guides(colour = "none") replaces the deprecated guides(colour = FALSE)
pgp <- ggplot(data, aes(x = x, y = yobs)) +
  geom_point(alpha = 0.2) +
  geom_line(data = basis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  geom_line(data = spline,
            mapping = aes(x = x, y = fitted),
            lwd = 1, colour = 'black', alpha = 0.75) +
  guides(colour = "none") +
  labs(y = 'f(x)', x = 'x', title = 'GAMs: learning from data',
       subtitle = 'Gaussian process — Matérn(κ=1.5; ρ=0.25)',
       caption = '@ucfagls') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
pgp

animate(pgp + transition_states(type, transition_length = 4, state_length = 2) +
          ease_aes('cubic-in-out'),
        nframes = 100, height = 700, width = 800, res = 120)

## unweighted and weighted basis functions side by side
ggplot(basis, aes(x = x, y = value, colour = bf)) +
  geom_line() +
  facet_wrap(~ type, scales = 'free_y')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment