Created
June 6, 2024 13:30
-
-
Save topepo/5b57761f98d6ef1afa0b262fc870a7d7 to your computer and use it in GitHub Desktop.
nested resampling in tidymodels
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Setup ----
# This script uses the prototype `tune_nested()`; the line below installs the
# development "nested" branch of finetune that (presumably) provides it.
# pak::pak(c("tidymodels/finetune@nested"), ask = FALSE)
library(tidymodels)
library(finetune)
library(rlang)
library(sfd)
library(doMC)

# ------------------------------------------------------------------------------

tidymodels_prefer()
theme_set(theme_bw())
# Tibble printing: no advice footer, and never truncate column titles.
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)

# doMC is a fork-based parallel backend (not available on Windows). Use up to
# 10 workers, but no more than the machine reports; detectCores() can return
# NA, which `na.rm = TRUE` drops so the cap falls back to 10.
registerDoMC(cores = min(10L, parallel::detectCores(), na.rm = TRUE))
# ------------------------------------------------------------------------------
# Simulated data ----

# True linear predictor, kept as an unevaluated expression in predictors A, B.
f <- expr(-1 - 4 * A - 2 * B - 0.2 * A^2 + 1 * B^2)

# Regular 100 x 100 grid over the predictor space with the true linear
# predictor `lp`; A spans [-3, 3] and B spans [-4, 4] (`x_seq`).
x_seq <- seq(-4, 4, length.out = 100)
grid <-
  crossing(A = seq(-3, 3, length.out = 100), B = x_seq) %>%
  # Inject the expression so it is evaluated against the grid's columns.
  mutate(lp = !!f)

# A training set (200 rows) and a new set (1,000 rows), both simulated from
# the same equation `f`.
set.seed(943)
sim_tr <- sim_logistic(200, f)
sim_new <- sim_logistic(1000, f)
# Resampling ----
# Ordinary 10-fold CV, and nested CV with 10 folds outside and inside. Both
# schemes use the same seed and the same `v`, so the outer folds of the
# nested scheme should match the ordinary folds.
set.seed(14)
sim_rs <- vfold_cv(sim_tr, v = 10)

set.seed(14)
sim_nested_rs <- nested_cv(
  sim_tr,
  outside = vfold_cv(v = 10),
  inside = vfold_cv(v = 10)
)

# Same outer resamples: the first split selects the same rows in both.
all.equal(
  sim_rs$splits[[1]]$in_id,
  sim_nested_rs$splits[[1]]$in_id
)

# ...although the data underneath the inner resamples is not the full training
# set, so this is expected to be FALSE (presumably each inner scheme is built
# on its outer split's analysis set).
identical(
  sim_rs$splits[[1]]$data,
  sim_nested_rs$inner_resamples[[1]]$splits[[1]]$data
)
# ------------------------------------------------------------------------------
# Model ----

# K-nearest neighbors classifier. The number of neighbors and the distance
# power are marked for tuning; the weighting kernel is fixed to "rectangular".
knn_spec <-
  nearest_neighbor(
    neighbors = tune(),
    dist_power = tune(),
    weight_func = "rectangular"
  ) %>%
  set_mode("classification")

# Bundle the model with the formula relating the outcome to A and B.
knn_wflow <-
  workflow() %>%
  add_model(knn_spec) %>%
  add_formula(class ~ A + B)
# ------------------------------------------------------------------------------
# Space-filling design ----

# 15 candidates over two tuning parameters: neighbors in [2, 50] and the
# default dist_power range.
sfd_size <- 15
knn_prm <- parameters(neighbors(c(2, 50)), dist_power())

# One vector of `sfd_size` candidate values per parameter. `rep_len()` guards
# against `value_seq()` returning fewer values than requested (presumably
# possible for integer parameters).
vals <- map(
  knn_prm$object,
  \(prm) rep_len(value_seq(prm, sfd_size), length.out = sfd_size)
)

# Max-min (L1) design in 2 dimensions. `update_values()` (sfd) presumably
# swaps the design's levels for the values above; columns take the parameter
# ids.
knn_sfd <-
  get_design(2, sfd_size, type = "max_min_l1") %>%
  update_values(vals) %>%
  setNames(knn_prm$id)

knn_sfd %>%
  ggplot(aes(neighbors, dist_power)) +
  geom_point()
# ------------------------------------------------------------------------------
# Grid search with ordinary resampling ----

# Evaluate every space-filling candidate on `sim_rs`, scored by the Brier
# score (lower is better).
knn_sfd_res <-
  knn_wflow %>%
  tune_grid(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )

autoplot(knn_sfd_res)
show_best(knn_sfd_res, metric = "brier_class", n = 1)
# ------------------------------------------------------------------------------
# Nested resampling of the grid search ----

# `tune_nested()` is the prototype from the finetune "nested" branch. With no
# `fn` given, it presumably runs a grid search within each set of inner
# resamples -- TODO confirm against the prototype.
knn_sfd_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )

collect_metrics(knn_sfd_nest_res)

# Tally how often each (neighbors, dist_power) pair appears in `.selected`
# (presumably the candidate chosen within each outer resample), most
# frequent first.
final_grid <-
  knn_sfd_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power, sort = TRUE) %>%
  mutate(method = "Grid Search")

final_grid
# ------------------------------------------------------------------------------
# Iterative optimization ----

# Simulated annealing with 4 initial candidates and 20 iterations, searching
# the same parameter ranges as the grid (`knn_prm`).
set.seed(97)
knn_sa_res <-
  knn_wflow %>%
  tune_sim_anneal(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm
  )

autoplot(knn_sa_res)
autoplot(knn_sa_res, type = "parameters")
show_best(knn_sa_res, metric = "brier_class", n = 1)
# ------------------------------------------------------------------------------
# Nested resampling of simulated annealing ----

# Same settings as the ordinary simulated annealing run, passed through
# `tune_nested()` with `fn = "tune_sim_anneal"`; per-iteration logging is
# turned off.
set.seed(97)
knn_sa_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    fn = "tune_sim_anneal",
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm,
    control = control_sim_anneal(verbose_iter = FALSE)
  )

collect_metrics(knn_sa_nest_res)

# Tally how often each (neighbors, dist_power) pair appears in `.selected`
# (presumably the candidate chosen within each outer resample), most
# frequent first.
final_sa <-
  knn_sa_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power, sort = TRUE) %>%
  mutate(method = "SA")

final_sa
# Compare the candidates selected by each search method ----
# Axis limits match the tuning ranges: neighbors in [2, 50] and dist_power
# in [1, 2].
bind_rows(final_grid, final_sa) %>%
  ggplot(aes(neighbors, dist_power, col = method, pch = method)) +
  geom_point() +
  lims(x = c(2, 50), y = c(1, 2)) +
  labs(x = neighbors()$label, y = dist_power()$label) +
  theme(legend.position = "top")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nested resampling does not optimize your model. Once you have optimized your model using functions like `tune_grid()` or `tune_bayes()`, nested resampling gives you a better (less optimistically biased) performance estimate than those functions can produce. In a way, it helps validate how you have already optimized your model.
The code is designed to mirror how you use the `tune_*()` functions: you pass the same `metrics` and model arguments as you would to those functions. As I said, it is a prototype; the documentation will be written once we have the code and user interface worked out.