Skip to content

Instantly share code, notes, and snippets.

@topepo
Created February 11, 2021 22:47
Show Gist options
  • Save topepo/9bd165222d49aa7ee54233de8812af5a to your computer and use it in GitHub Desktop.
Save topepo/9bd165222d49aa7ee54233de8812af5a to your computer and use it in GitHub Desktop.
LiblineaR, glmnet, and glm
# Optional install step, left commented out. The parsnip branch named below is
# presumably the one that provides the LiblineaR engine for logistic_reg() --
# TODO confirm whether it is still required with current parsnip releases.
# pak::pak("tidymodels/parsnip@logistic-liblinear")
library(AppliedPredictiveModeling)
library(tidymodels)
# Apply the black-and-white ggplot2 theme to every plot drawn in this script.
theme_set(theme_bw())
# ------------------------------------------------------------------------------
# Fit one logistic regression of `class` on all other columns of `dat` and
# return its tidied coefficients.
#
# Arguments:
#   pen  Penalty value handed to logistic_reg(); the script also passes NULL
#        for the unpenalized glm fit.
#   eng  Name of the parsnip engine to fit with (default "glmnet").
#   dat  Data frame containing the outcome column `class` and the predictors.
#   ...  Further arguments forwarded to logistic_reg() (e.g. `mixture`).
#
# Returns the output of tidy() on the fitted model, with `penalty` and
# `engine` columns set from `pen` and `eng`.
lr_pull <- function(pen, eng = "glmnet", dat, ...) {
  model_spec <- set_engine(logistic_reg(penalty = pen, ...), eng)
  model_fit <- fit(model_spec, class ~ ., dat)
  coef_tbl <- tidy(model_fit)
  # Tag every coefficient row with the settings that produced it so results
  # from different penalties and engines can be stacked later.
  mutate(coef_tbl, penalty = pen, engine = eng)
}
# Minimal tidy() method for LiblineaR fits.
#
# Reads the first row of the weight matrix `x$W`, relabels the element named
# "Bias" as "(Intercept)", and returns a tibble with columns `term` and
# `estimate`. `...` is accepted for compatibility with the generic and ignored.
#
# NOTE: only the first row of `W` is used; the original author flagged this
# method as needing more work in general.
tidy.LiblineaR <- function(x, ...) {
  weight_vec <- x$W[1, ]
  is_bias <- names(weight_vec) == "Bias"
  names(weight_vec)[is_bias] <- "(Intercept)"
  tibble::tibble(
    term = names(weight_vec),
    estimate = unname(weight_vec)
  )
}
# ------------------------------------------------------------------------------
# Simulated two-class data set.
#   True linear predictor: lp = 0 - 4 * x1 + 4 * x2
#   cor[x1, x2] is about 0.62
set.seed(1)
# Drop the `prob` column so that only the predictors and `class` remain.
dat <- select(easyBoundaryFunc(1000, interaction = 0), -prob)

# Scatter plot of the two predictors, coloured by class, on equal axis scales.
ggplot(data = dat, mapping = aes(x = X1, y = X2, colour = class)) +
  geom_point(alpha = 0.3) +
  coord_fixed(ratio = 1)
# ------------------------------------------------------------------------------
# Twenty penalty values, log-spaced from 1e-4 to 1e3. The same grid is reused
# under three names so each set of fits below can refer to its own vector.
lasso_penalties <- 10^seq(from = -4, to = 3, length.out = 20)
ridge_penalties <- big_penalties <- lasso_penalties
# Coefficient estimates across the penalty grids, one table per combination of
# engine (glmnet, LiblineaR) and penalty type. Fits with mixture = 1 are
# labelled "lasso" and fits with mixture = 0 are labelled "ridge".
glmn_lasso_res <- lasso_penalties %>%
  map_dfr(function(p) lr_pull(p, dat = dat, mixture = 1)) %>%
  mutate(model = "lasso")

ll_lasso_res <- big_penalties %>%
  map_dfr(function(p) lr_pull(p, eng = "LiblineaR", dat = dat, mixture = 1)) %>%
  mutate(model = "lasso")

glmn_ridge_res <- ridge_penalties %>%
  map_dfr(function(p) lr_pull(p, dat = dat, mixture = 0)) %>%
  mutate(model = "ridge")

ll_ridge_res <- big_penalties %>%
  map_dfr(function(p) lr_pull(p, eng = "LiblineaR", dat = dat, mixture = 0)) %>%
  mutate(model = "ridge")

# Unpenalized reference fit via the glm engine, recorded with penalty = 0.
glm_res <- lr_pull(pen = NULL, eng = "glm", dat = dat) %>%
  select(term, estimate) %>%
  mutate(penalty = 0, engine = "glm", model = "unpenalized")

glm_res
# Coefficient paths: estimate versus penalty (log10 x-axis), one line per term,
# with engines in facet rows and penalty types in facet columns. The intercept
# is left out, and dotted horizontal lines are drawn at -4 and 4.
list(glmn_lasso_res, ll_lasso_res, glmn_ridge_res, ll_ridge_res) %>%
  bind_rows() %>%
  dplyr::filter(term != "(Intercept)") %>%
  ggplot(mapping = aes(x = penalty, y = estimate, colour = term)) +
  geom_hline(yintercept = c(-4, 4), linetype = 3) +
  geom_line() +
  facet_grid(rows = vars(engine), cols = vars(model)) +
  scale_x_log10()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment