Skip to content

Instantly share code, notes, and snippets.

@jfaganUK
Created February 10, 2020 19:32
Show Gist options
  • Save jfaganUK/60aa1f910cca59ff0a44de6198700e44 to your computer and use it in GitHub Desktop.
Save jfaganUK/60aa1f910cca59ff0a44de6198700e44 to your computer and use it in GitHub Desktop.
## Animate fitting ordinary least squares
# Author: Jesse Michael Fagan, PhD
# Date: 2020-02-10T19:29
## Packages ########################################
# NOTE(review): rm(list = ls()) deletes every visible object in the current
# environment, so sourcing this script discards whatever the session held.
rm(list=ls())
# Run a garbage collection once the workspace has been emptied.
gc()
library(magick)
library(gganimate)
library(transformr)
library(tidyverse)
# Scatter-plot data embedded inline so the script needs no external file:
# a data.frame with 31 rows and two numeric columns, v1 and v2, each spanning
# 0 to 100. The `spec` attribute carries a col_spec object; presumably it is
# left over from the read_csv() call below that originally produced the data.
# xkcd <- read_csv('./curve_fitting.csv')
xkcd <- structure(list(v1 = c(0, 9.84848484848485, 14.3939393939394,
36.3636363636364, 34.0909090909091, 24.2424242424242, 21.2121212121212,
18.9393939393939, 12.1212121212121, 18.1818181818182, 39.3939393939394,
43.9393939393939, 40.9090909090909, 45.4545454545455, 50.7575757575758,
59.0909090909091, 64.3939393939394, 70.4545454545455, 76.5151515151515,
77.2727272727273, 76.5151515151515, 92.4242424242424, 89.3939393939394,
91.6666666666667, 87.1212121212121, 84.0909090909091, 77.2727272727273,
95.4545454545455, 93.1818181818182, 93.1818181818182, 100),
v2 = c(9.2436974789916,
41.1764705882353, 68.9075630252101, 51.2605042016807, 33.6134453781513,
26.890756302521, 21.0084033613445, 26.890756302521, 16.8067226890756,
11.7647058823529, 24.3697478991597, 21.0084033613445, 15.1260504201681,
12.6050420168067, 8.40336134453782, 11.7647058823529, 34.453781512605,
50.4201680672269, 57.1428571428571, 50.4201680672269, 42.0168067226891,
42.0168067226891, 0, 59.6638655462185, 68.9075630252101, 82.3529411764706,
100, 84.8739495798319, 76.4705882352941, 69.7478991596639, 67.2268907563025
)), class = c("data.frame"), row.names = c(NA,-31L),
spec = structure(list(cols = list(v1 = structure(list(), class = c("collector_double","collector")),
v2 = structure(list(), class = c("collector_double","collector"))),
default = structure(list(), class = c("collector_guess", "collector")), skip = 1), class = "col_spec"))
# Quick look at the raw data: scatter plot with a linear-model smoother on top.
ggplot(data = xkcd, mapping = aes(x = v1, y = v2)) +
  geom_point() +
  geom_smooth(method = 'lm')
### Squared error ##############################################
# Sum of squared errors of the line v2 = b1 * v1 + b0 over the xkcd data.
#
# param  : numeric vector; param[1] is the slope (b1), param[2] is the
#          intercept (b0).
# record : when TRUE (the default, which is what optim() uses) the evaluation
#          is appended to the global trace `params_l` and the global counter
#          `i` is incremented. Pass FALSE to evaluate the error without
#          touching either global.
#
# Returns the sum of squared residuals as a single number.
squared_error <- function(param, record = TRUE) {
  predicted <- xkcd$v1 * param[1] + param[2]
  es <- sum((predicted - xkcd$v2)^2)
  if (record) {
    # `<<-` is deliberate: the trace is kept in globals so that every call
    # made by the optimizer is captured for the animation.
    i <<- i + 1
    params_l[[i]] <<- data.frame(
      i = i,
      squared_error = es,
      b1 = param[1],
      b0 = param[2]
    )
  }
  es
}
# Reset the trace that squared_error() fills in, then let optim() search for
# the slope/intercept pair with the smallest squared error, starting from
# b1 = -1.8, b0 = -18.
i <- 0
params_l <- list()
o <- optim(
  par = c(b1 = -1.8, b0 = -18),
  fn = squared_error,
  control = list(trace = TRUE)
)
# Stack the per-evaluation records into one data frame, one row per call.
params <- do.call(rbind, params_l)
### Create animation frame ######################################
# Build one copy of the data per optimizer evaluation, annotated with the
# candidate line's predictions and residuals, and stack the copies into a
# single long data frame keyed by `state`.
# seq_len() is used instead of 1:nrow(params), which would iterate over
# c(1, 0) if `params` had no rows. The lambda argument is named `step` so it
# does not shadow the global evaluation counter `i`.
record_anim <- lapply(seq_len(nrow(params)), function(step) {
  b0 <- params$b0[step]
  b1 <- params$b1[step]
  xkcd %>%
    mutate(
      v2_pred = b0 + b1 * v1,
      residual = v2 - v2_pred,
      residual2 = residual^2,
      # squared residuals standardised within this step
      residual2z = residual2 %>% scale %>% as.numeric,
      m_v2 = mean(v2),
      param_b0 = b0,
      param_b1 = b1,
      state = step
    )
}) %>% bind_rows
# Equation label for the candidate line. It is kept only on rows where v1
# equals the overall minimum of v1 and set to NA everywhere else.
record_anim <- record_anim %>%
  mutate(f = sprintf('v2 = %1.2f v1 + %1.2f', param_b1, param_b0)) %>%
  mutate(f = ifelse(v1 == min(v1), f, NA_character_))
### Create animation ##################################
# Animated view of the fit: for every optimizer step each observation gets a
# rectangle spanned by its residual, and the candidate regression line is
# drawn through the predicted values.
p1 <- ggplot(data = record_anim) +
  # reference axes through the origin
  geom_hline(yintercept = 0) +
  geom_vline(xintercept = 0) +
  # residual squares: width and height both equal the residual in magnitude
  geom_rect(
    mapping = aes(
      xmin = v1, xmax = v1 + residual,
      ymin = v2, ymax = v2_pred,
      fill = residual2z
    ),
    alpha = 0.8
  ) +
  geom_point(mapping = aes(x = v1, y = v2)) +
  geom_line(mapping = aes(x = v1, y = v2_pred), color = 'black', size = 1) +
  # equation of the current candidate line, pinned at a fixed position
  geom_text(mapping = aes(label = f), x = 0.1, y = 80, hjust = 0, vjust = 0) +
  scale_fill_viridis_c('Squared Residual', option = 'plasma') +
  scale_x_continuous(name = 'v1') +
  scale_y_continuous(name = 'v2') +
  # equal scaling on both axes so the residual rectangles render as squares
  coord_equal(xlim = c(-25, 125), ylim = c(-25, 125)) +
  labs(
    title = 'Estimating OLS Regression',
    subtitle = 'Nelder-Mead minimizer step: {closest_state}',
    caption = 'Data source: https://xkcd.com/2048/ \n Created by @jessemfagan'
  ) +
  theme_minimal() +
  theme(legend.position = 'none') +
  # one animation state per optimizer evaluation, no wrap back to the start
  transition_states(
    state,
    transition_length = 3,
    state_length = 0,
    wrap = FALSE
  ) +
  ease_aes('linear')
# Device arguments (600 x 600) that gganimate hands to the graphics device
# when rendering frames.
options(gganimate.dev_args = list(width = 600, height = 600))
# Render the regression animation: 30 frames per second for 10 seconds, with
# the final frame repeated via end_pause.
p1_gif <- animate(plot = p1, fps = 30, duration = 10, end_pause = 100)
# save.image('animate_least_squares.rda')
# Restore a previously saved workspace only when the file exists. The
# save.image() call that would create it is commented out above, so an
# unconditional load() aborted the script whenever the file was absent.
# Without the file, the objects computed so far are simply kept.
# NOTE(review): when the file is present, load() overwrites any same-named
# objects in the workspace with the saved versions.
workspace_cache <- './animate_least_squares.rda'
if (file.exists(workspace_cache)) {
  load(workspace_cache)
}
### Plot the surface ############################################
# Pair every optimizer evaluation with the one before it, so that each step
# can be described by a start point (pb1, pb0) and an end point (b1, b0).
# The first evaluation has no predecessor, so its start point is itself.
param_path <- params %>%
  mutate(
    pb1 = coalesce(lag(b1), b1),
    pb0 = coalesce(lag(b0), b0)
  )
# Evaluate the squared error on a 100 x 100 grid of (intercept, slope) pairs.
# The error is computed directly here rather than through squared_error():
# that function appends a record to the global optimizer trace (i, params_l)
# on every call, so using it for the grid pushed 10,000 unrelated entries
# into the trace and built a data frame per grid point.
eg <- expand.grid(
  intercept = seq(-20, 40, length.out = 100),
  slope = seq(-2, 3, length.out = 100)
)
eg$est <- vapply(
  seq_len(nrow(eg)),
  function(k) sum((xkcd$v1 * eg$slope[k] + eg$intercept[k] - xkcd$v2)^2),
  numeric(1)
)
# Static view of the error surface over (slope, intercept): filled tiles plus
# 20 evenly spaced contour levels, with the optimizer's path drawn as arrows.
ggplot() +
  geom_tile(data = eg, mapping = aes(x = slope, y = intercept, fill = est)) +
  scale_fill_viridis_c('Error', direction = -1) +
  geom_contour(
    data = eg,
    mapping = aes(x = slope, y = intercept, z = est),
    breaks = seq(min(eg$est), max(eg$est), length.out = 20),
    color = 'black',
    alpha = 0.5
  ) +
  # one arrow per evaluation, from the previous parameters to the current ones
  geom_segment(
    data = param_path,
    mapping = aes(x = pb1, y = pb0, xend = b1, yend = b0),
    arrow = arrow(length = unit(0.1, 'cm'))
  ) +
  # plus-shaped marker at the parameters returned by optim()
  annotate('point', x = o$par[1], y = o$par[2], shape = 3, size = 5) +
  theme_minimal()
# Animated version of the error-surface plot: the tiles and contours stay
# fixed while the optimizer's steps are revealed one evaluation at a time.
p2 <- ggplot(data = param_path) +
  geom_tile(data = eg, mapping = aes(x = slope, y = intercept, fill = est)) +
  scale_fill_viridis_c('Error', direction = -1) +
  geom_contour(
    data = eg,
    mapping = aes(x = slope, y = intercept, z = est),
    breaks = seq(min(eg$est), max(eg$est), length.out = 20),
    color = 'black',
    alpha = 0.5
  ) +
  # arrow from the previous parameters to the current ones
  geom_segment(
    mapping = aes(x = pb1, y = pb0, xend = b1, yend = b0),
    arrow = arrow(length = unit(0.1, 'cm'))
  ) +
  # plus-shaped marker at the parameters returned by optim()
  annotate('point', x = o$par[1], y = o$par[2], shape = 3, size = 5) +
  coord_cartesian(xlim = c(-2, 3), ylim = c(-20, 40)) +
  theme_minimal() +
  theme(legend.position = 'none') +
  # one animation state per evaluation index, no wrap back to the start
  transition_states(
    i,
    transition_length = 3,
    state_length = 0,
    wrap = FALSE
  ) +
  # leave a short fading trail behind the moving arrow
  shadow_wake(wake_length = 0.05) +
  ease_aes('linear')
# Device arguments (600 x 600) that gganimate hands to the graphics device
# when rendering frames.
options(gganimate.dev_args = list(width = 600, height = 600))
# Render the parameter-path animation: 30 frames per second for 10 seconds,
# with the final frame repeated via end_pause.
p2_gif <- animate(plot = p2, fps = 30, duration = 10, end_pause = 100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment