# Created September 25, 2023 00:58
# Save benjaminrich/a0b5b1e6cbd269678cd5e90a90268aa6 to your computer and use it in GitHub Desktop.
# Stepwise regression
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
# S3 generics ----
# Each generic only dispatches on the class of its first argument; the actual
# work lives in the methods. (The stray trailing `| |` tokens left over from
# the web export were removed so that the definitions parse.)

# Repeated forward steps starting from `base_fit`; default threshold 0.05.
stepwise_forward <- function(base_fit, candidates, alpha = 0.05, ...) {
  UseMethod("stepwise_forward")
}

# Repeated backward steps starting from `base_fit`; default threshold 0.01.
stepwise_backward <- function(base_fit, candidates, alpha = 0.01, ...) {
  UseMethod("stepwise_backward")
}

# A single forward (addition) step.
forward_step <- function(base_fit, candidates, alpha = 0.05, ...) {
  UseMethod("forward_step")
}

# A single backward (removal) step.
backward_step <- function(base_fit, candidates, alpha = 0.01, ...) {
  UseMethod("backward_step")
}

# Fit one model per formula in `all_formulas`, derived from `base_fit`.
fit_all_models <- function(base_fit, all_formulas, ...) {
  UseMethod("fit_all_models")
}

# Comparison table for a collection of fitted models.
model_table <- function(obj, ...) {
  UseMethod("model_table")
}

# Accessors for the information attached to step/search results.
pvalue <- function(obj, ...) {
  UseMethod("pvalue")
}

selected <- function(obj, ...) {
  UseMethod("selected")
}

final_model <- function(obj, ...) {
  UseMethod("final_model")
}
# Default accessors ----
# Results carry their metadata as attributes; these methods read the attribute
# of the same name. `exact = TRUE` disables partial attribute-name matching.
# (Stray trailing `| |` export tokens removed so the definitions parse.)

# Returns the "pvalue" attribute of `obj`, or NULL when it is absent.
pvalue.default <- function(obj, ...) {
  attr(obj, "pvalue", exact = TRUE)
}

# Returns the "selected" attribute of `obj`, or NULL when it is absent.
selected.default <- function(obj, ...) {
  attr(obj, "selected", exact = TRUE)
}

# Returns the "final_model" attribute of `obj`, or NULL when it is absent.
final_model.default <- function(obj, ...) {
  attr(obj, "final_model", exact = TRUE)
}
# Drive a stepwise search by applying `step_fn` repeatedly.
#
# base_fit    Starting model.
# candidates  Terms still available to `step_fn`.
# alpha       Threshold handed to `step_fn` unchanged.
# step_fn     Function called as step_fn(fit, candidates, alpha); its result
#             is queried with selected() and final_model().
#
# Returns a one-element list holding the list of step results, with class
# "stepwise_search" and attributes `selected` (all terms chosen, in order) and
# `final_model` (the model after the last successful step).
stepwise_search <- function(base_fit, candidates, alpha, step_fn) {
  best_fit <- base_fit

  # Every successful step removes at least one candidate, so the number of
  # candidates bounds the number of steps. Looping over that bound (instead of
  # `while (TRUE)`) guarantees termination even if a step reports a selection
  # that is not among the candidates, and lets the result list be preallocated.
  max_steps <- length(candidates)
  steps <- vector("list", max_steps)
  n_steps <- 0L

  for (k in seq_len(max_steps)) {
    if (length(candidates) == 0) break
    step <- step_fn(best_fit, candidates, alpha)
    n_steps <- n_steps + 1L
    steps[n_steps] <- list(step)  # `[<-` with list() keeps a NULL step intact

    chosen <- selected(step)
    if (is.null(chosen)) break    # nothing passed the threshold: stop here

    best_fit <- final_model(step)
    candidates <- setdiff(candidates, chosen)
  }
  steps <- steps[seq_len(n_steps)]

  structure(
    list(steps),
    class = "stepwise_search",
    selected = unlist(lapply(steps, selected)),
    final_model = best_fit
  )
}
# Default method: forward stepwise selection.
#
# base_fit    Starting model.
# candidates  Terms that may be added.
# alpha       Threshold passed to each forward step (default 0.05).
# ...         Extra arguments forwarded to forward_step(). The generic declares
#             `...`, so the method accepts it too; previously any extra
#             argument failed with "unused argument".
#
# Returns the search result renamed to a one-element list called "forward",
# with class "stepwise_forward" and `selected` wrapped as list(forward = ...).
# Attributes not overridden here (e.g. `final_model`) carry over from the
# search result.
stepwise_forward.default <- function(base_fit, candidates, alpha = 0.05, ...) {
  res <- stepwise_search(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    # The closure keeps the three-argument calling convention used by
    # stepwise_search() while still passing `...` through.
    step_fn = function(fit, remaining, alpha) {
      forward_step(fit, remaining, alpha, ...)
    }
  )
  structure(
    setNames(res, "forward"),
    class = "stepwise_forward",
    selected = list(forward = selected(res))
  )
}
# Default method: backward stepwise elimination.
#
# base_fit    Starting model.
# candidates  Terms that may be removed.
# alpha       Threshold passed to each backward step (default 0.01).
# ...         Extra arguments forwarded to backward_step(). The generic
#             declares `...`, so the method accepts it too; previously any
#             extra argument failed with "unused argument".
#
# Returns the search result renamed to a one-element list called "backward",
# with class "stepwise_backward" and `selected` wrapped as
# list(backward = ...). Attributes not overridden here (e.g. `final_model`)
# carry over from the search result.
stepwise_backward.default <- function(base_fit, candidates, alpha = 0.01, ...) {
  res <- stepwise_search(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    # The closure keeps the three-argument calling convention used by
    # stepwise_search() while still passing `...` through.
    step_fn = function(fit, remaining, alpha) {
      backward_step(fit, remaining, alpha, ...)
    }
  )
  structure(
    setNames(res, "backward"),
    class = "stepwise_backward",
    selected = list(backward = selected(res))
  )
}
# Backward elimination applied to the result of a forward search.
#
# base_fit    A "stepwise_forward" result.
# candidates  Terms that may be removed; when omitted, the terms selected by
#             the forward search are used.
# alpha       Threshold for the backward steps (default 0.01).
# ...         Extra arguments passed on to stepwise_backward() for the final
#             forward model (accepted because the generic declares `...`).
#
# Returns the forward and backward results concatenated into one list, with
# class "stepwise_forward_backward", the combined `selected` lists, and the
# backward search's final model as `final_model`.
stepwise_backward.stepwise_forward <- function(base_fit, candidates, alpha = 0.01, ...) {
  if (missing(candidates)) {
    # Only what the forward pass added is eligible for removal by default.
    candidates <- unlist(selected(base_fit), use.names = FALSE)
  }
  back_fit <- stepwise_backward(final_model(base_fit), candidates, alpha, ...)
  structure(
    c(base_fit, back_fit),
    class = "stepwise_forward_backward",
    selected = c(selected(base_fit), selected(back_fit)),
    final_model = final_model(back_fit)
  )
}
# One step of stepwise selection, shared by the forward and backward methods.
#
# base_fit    Current model; must support formula() and `$data`.
# candidates  Terms to try adding (forward) or removing (backward).
# alpha       Threshold compared against the chosen p-value via `op`.
# direction   "forward" or "backward".
# op          Comparison used as op(p, alpha); by default `<` for forward and
#             `>=` for backward. The default is evaluated lazily, after
#             `direction` has been resolved by match.arg().
# ...         Passed on to fit_all_models().
#
# Forward picks the candidate with the smallest p-value, backward the one with
# the largest, and accepts it when op(p, alpha) is TRUE. Returns the model
# table with attributes `base_fit`, `all_fits`, `selected` (NULL when nothing
# was accepted) and `final_model` (the accepted fit, otherwise `base_fit`).
generic_step <- function(
  base_fit,
  candidates,
  alpha,
  direction = c("forward", "backward"),
  op = if (direction=="forward") `<` else `>=`,
  ...
) {
  direction <- match.arg(direction)
  is_forward <- direction == "forward"

  all_formulas <- derive_all_formulas(
    base_fit,
    base_formula = formula(base_fit),
    add = if (is_forward) candidates else NULL,
    subtract = if (is_forward) NULL else candidates
  )
  all_fits <- fit_all_models(base_fit, all_formulas, data = base_fit$data, ...)
  mtab <- model_table(all_fits, base_fit, direction = direction, sort = TRUE)
  pval <- pvalue(mtab)

  # which.min()/which.max() skip NAs and return integer(0) when there is
  # nothing to pick from. Treat that as "no candidate accepted" rather than
  # letting `if` fail on a zero-length condition.
  best <- if (is_forward) which.min(pval) else which.max(pval)
  accepted <- length(best) == 1L && isTRUE(op(pval[best], alpha))

  # Local names deliberately differ from the selected()/final_model() generics.
  if (accepted) {
    chosen <- names(pval)[best]
    chosen_fit <- all_fits[[chosen]]
  } else {
    chosen <- NULL
    chosen_fit <- base_fit
  }

  structure(
    mtab,
    base_fit = base_fit,
    all_fits = all_fits,
    selected = chosen,
    final_model = chosen_fit
  )
}
# Default method for a single forward step.
#
# Delegates to generic_step() with direction = "forward" and prepends
# "forward_step" to the class of the result. `...` is passed through to
# generic_step(). (Stray trailing `| |` export tokens removed so the
# definition parses.)
forward_step.default <- function(base_fit, candidates, alpha = 0.05, ...) {
  step <- generic_step(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    direction = "forward",
    ...
  )
  class(step) <- c("forward_step", class(step))
  step
}
# Default method for a single backward step.
#
# Delegates to generic_step() with direction = "backward" and prepends
# "backward_step" to the class of the result. `...` is passed through to
# generic_step(). (Stray trailing `| |` export tokens removed so the
# definition parses.)
backward_step.default <- function(base_fit, candidates, alpha = 0.01, ...) {
  step <- generic_step(
    base_fit = base_fit,
    candidates = candidates,
    alpha = alpha,
    direction = "backward",
    ...
  )
  class(step) <- c("backward_step", class(step))
  step
}
# Labels for each argument: its names when it has them, otherwise its values
# converted to character.
#
# Returns a list with one character vector per argument. For a partially
# named vector the blank (or NA) names fall back to the corresponding value,
# so no element ends up labelled "" (previously names() was used as-is).
get_names <- function(...) {
  label_one <- function(x) {
    labels <- names(x)
    if (is.null(labels)) {
      return(as.character(x))
    }
    blank <- is.na(labels) | labels == ""
    if (any(blank)) {
      labels[blank] <- as.character(x)[blank]
    }
    labels
  }
  lapply(list(...), label_one)
}
# Build one updated formula per term to add or remove.
#
# base_fit       Model used only to supply the default `base_formula`.
# base_formula   Formula to update.
# add, subtract  Character vectors of terms; each term yields its own formula
#                (".~.+term" or ".~.-term" applied to `base_formula`).
# formula_names  Names for the resulting list; by default the labels of `add`
#                followed by those of `subtract`.
#
# Returns a list of formulas named by `formula_names`. Zero-length `add` and
# `subtract` are treated like NULL, and when there is nothing to add or remove
# an empty named list is returned. Previously character(0) was pasted into the
# unparseable update ".~.+".
derive_all_formulas <- function(
  base_fit,
  base_formula = formula(base_fit),
  add = NULL,
  subtract = NULL,
  formula_names = unlist(get_names(add, subtract))
) {
  changes <- c(
    if (length(add) > 0) paste0("+", add),
    if (length(subtract) > 0) paste0("-", subtract)
  )
  if (length(changes) == 0) {
    return(setNames(list(), character(0)))
  }
  updated <- lapply(
    paste0(".~.", changes),
    function(spec) update.formula(old = base_formula, new = spec)
  )
  setNames(updated, formula_names)
}
# Default method: refit `base_fit` once per formula.
#
# base_fit      Model to update.
# all_formulas  List of formulas handed to update() one at a time.
# data          Data used for the refits; defaults to `base_fit$data`.
# model_names   Names for the returned list; defaults to names(all_formulas).
# ...           Further arguments passed to update().
#
# Returns a list of refitted models named by `model_names`.
# (Stray trailing `| |` export tokens removed so the definition parses.)
fit_all_models.default <- function(
  base_fit,
  all_formulas,
  data = base_fit$data,
  model_names = names(all_formulas),
  ...
) {
  # update() is called from inside this local function so that the `data`
  # symbol in the rebuilt call resolves to this method's `data` argument.
  refit <- function(new_formula) {
    update(base_fit, formula. = new_formula, data = data, ...)
  }
  fits <- lapply(all_formulas, refit)
  setNames(fits, model_names)
}
# Default method: likelihood-ratio comparison of each fit against `base_fit`.
#
# all_fits    List of fitted models supporting logLik(); list names become the
#             `Model` column (positions are used when the list is unnamed).
# base_fit    Reference model.
# alpha       Unused here; kept for interface compatibility.
# direction   "forward" (fits extend the base) or "backward" (fits reduce it);
#             for "backward" the differences are sign-flipped.
# sort        Order rows by p-value?
# decreasing  Passed to order() when sorting.
#
# Returns a data frame with class "model_table" prepended and attributes
# `all_fits`, `base_fit` and `pvalue` (p-values named by model, in row order).
# The p-value is the upper tail of a chi-squared distribution evaluated at the
# change in -2*loglik, with the change in df as degrees of freedom.
model_table.default <- function(
  all_fits,
  base_fit,
  alpha,
  direction = c("forward", "backward"),
  sort = TRUE,
  decreasing = FALSE,
  ...
) {
  # Resolve the choice to a single string. Without this, the default length-2
  # vector reached ifelse() and produced two rows per model.
  direction <- match.arg(direction)
  flip <- function(x) if (direction == "forward") x else -x

  model_names <- names(all_fits)
  if (is.null(model_names)) {
    model_names <- as.character(seq_along(all_fits))
  }

  # The base model's log-likelihood does not depend on the candidate fit, so
  # compute it once rather than once per model.
  base_ll <- logLik(base_fit)
  base_dev <- -2 * as.numeric(base_ll)
  base_df <- attr(base_ll, "df", exact = TRUE)

  # Builds the table rows for vectors of deviances and df; works for length 0.
  build_rows <- function(dev, fit_df) {
    n <- length(dev)
    delta_dev <- flip(base_dev - dev)
    delta_df <- flip(fit_df - base_df)
    data.frame(
      check.names = FALSE,
      `Model` = rep_len(NA_character_, n),
      `-2*loglik` = dev,
      `df` = fit_df,
      `Base(-2*loglik)` = rep_len(base_dev, n),
      `Base(df)` = rep_len(base_df, n),
      `Δ(-2*loglik)` = delta_dev,
      `Δdf` = delta_df,
      `P-value` = pchisq(delta_dev, delta_df, lower.tail = FALSE)
    )
  }

  rows <- lapply(all_fits, function(fit) {
    ll <- logLik(fit)
    build_rows(-2 * as.numeric(ll), attr(ll, "df", exact = TRUE))
  })
  mtab <- if (length(rows) > 0) {
    do.call(rbind, rows)
  } else {
    # No fits: return an empty table with the same columns instead of failing.
    build_rows(numeric(0), numeric(0))
  }
  mtab[["Model"]] <- model_names

  if (sort) {
    mtab <- mtab[order(mtab[["P-value"]], decreasing = decreasing), ]
  }

  structure(
    mtab,
    class = c("model_table", class(mtab)),
    all_fits = all_fits,
    base_fit = base_fit,
    pvalue = setNames(mtab[["P-value"]], mtab[["Model"]])
  )
}
# Usage examples. Guarded by `if (FALSE)` so nothing runs when the file is
# sourced; execute the lines interactively. Requires the mvtnorm package.
# (Stray trailing `| |` export tokens removed so the block parses.)
if (FALSE) {
  library(mvtnorm)

  # Simulated data: x1-x4 are drawn jointly from a multivariate normal whose
  # covariance comes from a Wishart draw; x5-x9 are independent noise.
  set.seed(123)
  n <- 100
  p <- 4
  S <- rWishart(1, p, toeplitz(1/(1:p)))[,,1]
  x <- rmvnorm(n, rep(0, p), S)
  dat <- data.frame(
    x1 = x[,1],
    x2 = x[,2],
    x3 = x[,3],
    x4 = x[,4],
    x5 = rnorm(n),
    x6 = rnorm(n),
    x7 = rnorm(n),
    x8 = rnorm(n),
    x9 = rnorm(n)
  )
  # The response depends on x1-x4 only.
  dat$y <- with(dat, 6 + 0.3*x1 + 0.1*x2 + 0.4*x3 + 0.2*x4 + rnorm(n, 0, 1.3))

  base_fit <- glm(y ~ x1, data=dat)
  candidates <- c(
    "x2",
    "x3",
    "x4",
    "x5",
    "x6",
    "x7",
    "x8",
    "x9"
  )

  # A single forward step.
  x <- forward_step(base_fit, candidates, alpha=0.05)
  x
  final_model(x)

  # A single backward step.
  x <- backward_step(base_fit, "x1", alpha=0.01)
  x
  final_model(x)

  # Full forward search.
  x <- stepwise_forward(base_fit, candidates, alpha=0.05)
  x
  selected(x)
  final_model(x)

  # Backward search starting from a plain model.
  y <- stepwise_backward(base_fit, c("x3"), alpha=0.01)
  y
  selected(y)
  final_model(y)

  # Backward search starting from the forward-search result.
  y <- stepwise_backward(x, alpha=0.01)
  y
  selected(y)
  final_model(y)

  # Forward then backward as one pipeline.
  x <- base_fit |>
    stepwise_forward(candidates, alpha=0.05) |>
    stepwise_backward(alpha=0.01)
  x
  selected(x)
  final_model(x)
}
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.