# Created October 30, 2020 22:15
# Source: GitHub gist daob/3fb154fe0c55f6bfef2a7c3fe46a2b97
# Note: GitHub warns that this file may contain hidden or bidirectional
# Unicode text that can be interpreted or compiled differently than it
# appears; review it in an editor that reveals hidden Unicode characters.
# A small simulation to investigate bagging
# 2020-10-30 DLO
# Fix the RNG seed so every run of this script draws identical data,
# bootstrap samples and true weights.
set.seed(seed = 4232)
# Root mean squared error between predictions `y_pred` and observed `y`:
# the square root of the average squared residual.
rmse <- function(y_pred, y) {
  sqrt(mean((y_pred - y)^2))
}
# Score a fitted model on a test set: predict on `dat_test` and compare the
# predictions with its outcome column `y` via `metric(predicted, observed)`
# (root mean squared error by default).
metric_fit <- function(fit, dat_test, metric = rmse) {
  metric(predict(fit, newdata = dat_test), dat_test$y)
}
# Train a deliberately unstable model: a linear regression of `y` on all
# other columns together with every interaction up to third order.
fit_mod <- function(dat) {
  lm(data = dat, formula = y ~ .^3)
}
# Fit the model to one bootstrap resample of `dat`: draw nrow(dat) row
# indices with replacement and hand the resampled rows to fit_mod().
# `i_boot` is not used; it only lets the function be lapply()'d over
# replicate indices.
boot_fit <- function(i_boot, dat) {
  n_obs <- nrow(dat)
  resampled <- dat[sample.int(n_obs, size = n_obs, replace = TRUE), ]
  fit_mod(resampled)
}
# Data generating process (DGP): draw `n` rows with length(beta) features.
# Each feature equals a standard-normal term shared by the whole row plus
# its own standard-normal term. The features therefore have variance 2 and
# pairwise correlation 0.5. The outcome is the linear predictor
# features %*% beta plus standard-normal noise.
# Returns a data frame with feature columns x.1, ..., x.J (just `x` when
# J == 1) and outcome column `y`.
dgp <- function(n, beta) {
  n_feat <- length(beta)
  shared <- rnorm(n)              # one value per row, recycled over columns
  own <- rnorm(n * n_feat)        # independent per-cell term
  features <- matrix(shared + own, nrow = n)
  outcome <- features %*% beta + rnorm(n)
  data.frame(x = features, y = outcome)
}
# Run one simulation replicate: draw a fresh training sample from the DGP,
# fit the model once directly and once with bagging, and score both on
# `dat_pop`.
#
# it           Replicate index; only used to advance the progress bar.
# n            Training-sample size drawn from dgp().
# n_boot       Number of bootstrap fits to bag (a single number >= 1).
# dat_pop      Data frame used as the test set (and, optionally, as a
#              finite population -- see the commented-out lines below).
# beta         True feature weights passed to dgp().
# progress_bar A txtProgressBar to update, or NULL to run silently.
# combine      Function that reduces the bootstrap predictions for one test
#              row to a single bagged prediction. Defaults to median, which
#              the author observed to work much better than mean here.
#
# Returns c(rmse_train = ..., rmse_bagged = ...). Despite its name,
# `rmse_train` is the RMSE of the single (non-bagged) fit evaluated on
# `dat_pop`, not on the training sample.
sim <- function(it, n, n_boot, dat_pop, beta, progress_bar = NULL,
                combine = median) {
  # 1:n_boot would silently fit two models for n_boot = 0; require >= 1
  stopifnot(length(n_boot) == 1L, n_boot >= 1)
  if (!is.null(progress_bar)) {
    setTxtProgressBar(progress_bar, it)
  }
  # Draw fresh sample from the DGP (=infinite population)
  dat_sam <- dgp(n, beta = beta)
  # Uncomment below for finite population setup
  # idx_sam <- sample.int(nrow(dat_pop), size = n, replace = FALSE)
  # dat_sam <- dat_pop[idx_sam, ]
  # Fit and score the model without anything special
  fit_sam <- fit_mod(dat_sam)
  rmse_train <- metric_fit(fit = fit_sam, dat_test = dat_pop)
  # Draw n_boot bootstrap samples and fit, returning a list of models
  fit_boot <- lapply(seq_len(n_boot), boot_fit, dat = dat_sam)
  # One row per test observation, one column per bootstrap model. Built with
  # matrix() rather than sapply() so it stays a matrix even when dat_pop
  # has a single row (sapply would collapse it and break apply()).
  y_pred_boot <- matrix(
    unlist(lapply(fit_boot, predict, newdata = dat_pop), use.names = FALSE),
    nrow = nrow(dat_pop)
  )
  # Bag: reduce each row's bootstrap predictions to one prediction
  y_pred_bagged <- apply(y_pred_boot, 1, combine)
  rmse_bagged <- rmse(y_pred_bagged, dat_pop$y)
  c(rmse_train = rmse_train, rmse_bagged = rmse_bagged)
}
# Number of features
J <- 8
# True weights on features in linear predictor
beta <- runif(J, min = -5, max = +5)
# A "population" to use as test set, generated from DGP
dat_pop <- dgp(2e3, beta) # data frame
# Number of simulation replicates
nsim <- 50
# Nice-looking progress bar, updated inside sim(). Use min = 0 rather than
# 1: txtProgressBar() requires max > min, so min = 1 fails when nsim == 1.
pb <- txtProgressBar(min = 0, max = nsim, style = 3)
# Run the simulation. vapply() (rather than sapply()) checks that every
# replicate returns exactly c(rmse_train, rmse_bagged), so after t() `res`
# is always an nsim x 2 matrix with those column names.
system.time({ # time result
  res <- t(vapply(seq_len(nsim), # Replicate indices
    sim, # Call simulation function
    FUN.VALUE = c(rmse_train = 0, rmse_bagged = 0), # Expected result shape
    n = 100, # Sample size in each replicate
    n_boot = 60, # No. bagged bootstrap samples
    dat_pop = dat_pop, # Population used as test set
    beta = beta, # True beta weights
    progress_bar = pb # Reassuring progress bar
  ))
  # Terminate the progress-bar line so later output starts on a new line
  close(pb)
})
# Compare performance of bagging with regular fit. Both columns of `res` are
# RMSEs on dat_pop; `rmse_train` is the single (non-bagged) fit.
summary(res)
# Bagging is better when:
# * Model is over-parameterized (e.g. y ~ .^3)
# * Sample size is low
# When the model is very unstable, just a few bootstraps (5) already do the trick
# See how the two compare in different samples.
# Axis limits run from the smallest RMSE to the 90% quantile of all RMSEs,
# so the most extreme replicates deliberately fall outside the plot.
all_rmse <- as.vector(res)
limits <- c(min(all_rmse), quantile(all_rmse, 0.9, names = FALSE))
# x-axis: column 1 of res (single fit); y-axis: column 2 (bagged fit).
# Points below the diagonal are replicates where bagging had lower RMSE.
plot(res, pch = 19, xlim = limits, ylim = limits,
     main = "Bagged vs. Non-bagged over samples",
     xlab = "Test RMSE, no bagging (rmse_train)",
     ylab = "Test RMSE, bagged (rmse_bagged)",
     col = "#00000099")
abline(0, 1, col = "#66000077")
mid <- mean(limits)
text(mid, mid * 0.96, labels = "Bagging is better",
     adj = 0.5, srt = 45, col = "#00000077")
text(mid, mid * 1.04, labels = "Not bagging is better",
     adj = 0.5, srt = 45, col = "#00000077")
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.