Created
June 6, 2024 13:30
-
-
Save topepo/5b57761f98d6ef1afa0b262fc870a7d7 to your computer and use it in GitHub Desktop.
nested resampling in tidymodels
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# tune_nested() is used below; it appears to come from the prototype `nested`
# branch of finetune, which can be installed with:
# pak::pak(c("tidymodels/finetune@nested"), ask = FALSE)
library(tidymodels)
library(finetune)
library(rlang)
library(sfd)
library(doMC)

# ------------------------------------------------------------------------------
# Session setup

tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)

# Register up to 10 parallel workers, but never more than the machine reports.
# detectCores() may return NA; na.rm = TRUE then falls back to 10.
n_workers <- min(10L, parallel::detectCores(), na.rm = TRUE)
registerDoMC(cores = n_workers)
# ------------------------------------------------------------------------------ | |
# Expression in the two predictors A and B; it is passed to sim_logistic() as
# the data-generating equation and evaluated on a grid below.
f <- expr(-1 - 4 * A - 2 * B - 0.2 * A^2 + 1 * B^2)

# Regular 100 x 100 grid over (A, B) with the value of `f` at each point.
x_seq <- seq(-4, 4, length.out = 100)
grid <-
  crossing(A = seq(-3, 3, length.out = 100), B = x_seq) %>%
  mutate(
    lp = rlang::eval_tidy(f, data = .)
  )

# Training set (n = 200) and a larger set of new samples (n = 1000).
set.seed(943)
sim_tr <- sim_logistic(200, f)
sim_new <- sim_logistic(1000, f)

# Ordinary V-fold cross-validation of the training set.
set.seed(14)
sim_rs <- vfold_cv(sim_tr)

# Nested resampling; the seed is reset to the same value so that the outer
# folds can be compared with `sim_rs`.
set.seed(14)
sim_nested_rs <- nested_cv(sim_tr, outside = vfold_cv(), inside = vfold_cv())

# Same outer resamples: the row indices of the first outer split are compared.
all.equal(
  sim_rs$splits[[1]]$in_id,
  sim_nested_rs$splits[[1]]$in_id
)

# ...but the data behind the first inner split is compared with the data
# behind the first outer split. NOTE(review): presumably FALSE, since inner
# resamples should be built from the outer analysis set only -- confirm.
identical(
  sim_rs$splits[[1]]$data,
  sim_nested_rs$inner_resamples[[1]]$splits[[1]]$data
)
# ------------------------------------------------------------------------------ | |
# K-nearest-neighbors classifier: the number of neighbors and the Minkowski
# distance power are marked for tuning; the weighting function is fixed.
knn_spec <-
  nearest_neighbor(
    neighbors = tune(),
    dist_power = tune(),
    weight_func = "rectangular"
  ) %>%
  set_mode("classification")

# Workflow bundling the model with the formula `class ~ A + B`.
knn_wflow <-
  workflow() %>%
  add_model(knn_spec) %>%
  add_formula(class ~ A + B)
# ------------------------------------------------------------------------------ | |
# Number of candidate points in the space-filling design.
sfd_size <- 15

# Tuning parameters and their ranges (neighbors restricted to 2-50).
knn_prm <- parameters(neighbors(c(2, 50)), dist_power())

# One vector of exactly `sfd_size` candidate values per parameter. rep_len()
# recycles the sequence in case value_seq() returns fewer than `sfd_size`
# values.
vals <- map(
  knn_prm$object,
  \(prm) rep_len(value_seq(prm, sfd_size), length.out = sfd_size)
)

# Design with one column per parameter; the column count is taken from
# `knn_prm` rather than hard-coded, and columns are named by parameter id.
knn_sfd <-
  get_design(nrow(knn_prm), sfd_size, type = "max_min_l1") %>%
  update_values(vals) %>%
  setNames(knn_prm$id)

# Visual check of the candidate points.
knn_sfd %>%
  ggplot(aes(neighbors, dist_power)) +
  geom_point()
# ------------------------------------------------------------------------------ | |
# Conventional grid search: evaluate every design point on the (non-nested)
# resamples, scoring with the Brier score.
knn_sfd_res <-
  knn_wflow %>%
  tune_grid(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )

autoplot(knn_sfd_res)

# Single best candidate according to the Brier score.
show_best(knn_sfd_res, metric = "brier_class", n = 1)
# ------------------------------------------------------------------------------ | |
# Nested version of the grid search: same metrics and grid as tune_grid()
# above, but run on the nested resamples.
knn_sfd_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )

collect_metrics(knn_sfd_nest_res)

# Tally the parameter combinations stored in the `.selected` column, most
# frequent first, and label them for the comparison plot.
final_grid <-
  knn_sfd_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power, sort = TRUE) %>%
  mutate(method = "Grid Search")

final_grid
# ------------------------------------------------------------------------------ | |
# Iterative optimization: simulated annealing on the (non-nested) resamples,
# starting from 4 initial candidates and running 20 iterations.
set.seed(97)
knn_sa_res <-
  knn_wflow %>%
  tune_sim_anneal(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm
  )

autoplot(knn_sa_res)
autoplot(knn_sa_res, type = "parameters")

# Single best candidate according to the Brier score.
show_best(knn_sa_res, metric = "brier_class", n = 1)
# ------------------------------------------------------------------------------ | |
# Nested version of the simulated-annealing search: `fn` names the tuning
# function to run, and the remaining arguments mirror the tune_sim_anneal()
# call on the non-nested resamples.
set.seed(97)
knn_sa_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    fn = "tune_sim_anneal",
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm,
    control = control_sim_anneal(verbose_iter = FALSE)
  )

collect_metrics(knn_sa_nest_res)

# Tally the parameter combinations stored in the `.selected` column, most
# frequent first, and label them for the comparison plot.
final_sa <-
  knn_sa_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power, sort = TRUE) %>%
  mutate(method = "SA")

final_sa

# Compare the selections made by the two search methods. The axis limits
# repeat the neighbors range given to `knn_prm`; the y limits are presumably
# dist_power()'s default range -- confirm.
bind_rows(final_grid, final_sa) %>%
  ggplot(aes(neighbors, dist_power, colour = method, shape = method)) +
  geom_point() +
  lims(x = c(2, 50), y = 1:2) +
  labs(x = neighbors()$label, y = dist_power()$label) +
  theme(legend.position = "top")
Nested resampling does not optimize your model. Once you have optimized your model using functions like `tune_grid()` or `tune_bayes()`, nested resampling gives you a better performance estimate than those functions can produce.
In a way, it helps validate how you have already optimized your model.
The code is designed to mirror how you used the `tune_*()` functions. You use the same metrics and model arguments as those functions.
As I said, it is a prototype; the documentation will be written once we have the code and user interface worked out.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi Max, Thanks for your prompt response! I am currently working on a study using Tidymodels and need to implement nested cross-validation to prevent data leakage. I am struggling with tuning my model using `nested_cv`. I have reviewed the example provided on the website and a few other scripts I found on GitHub, but with no luck. It seems like using the `tune_nested` function might not be the ideal solution for my case. If you don't mind, could you point me to any available examples or resources on how to tune hyperparameters for a classification model using `nested_cv`? All the examples I found were for regression models. Thanks a lot!