

@topepo
Created June 6, 2024 13:30
nested resampling in tidymodels
# tune_nested() lives in the prototype "nested" branch of finetune:
# pak::pak(c("tidymodels/finetune@nested"), ask = FALSE)
library(tidymodels)
library(finetune)
library(rlang)
library(sfd)
library(doMC)
# ------------------------------------------------------------------------------
tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
registerDoMC(cores = 10)
# ------------------------------------------------------------------------------
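# True linear predictor (logit scale) used to simulate the two-class outcome
f <- expr(-1 - 4 * A - 2 * B - 0.2 * A^2 + 1 * B^2)
x_seq <- seq(-4, 4, length.out = 100)
# Evaluate the true linear predictor over a regular grid of A and B
grid <-
  crossing(A = seq(-3, 3, length.out = 100), B = x_seq) %>%
  mutate(
    lp = rlang::eval_tidy(f, data = .)
  )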
set.seed(943)
sim_tr <- sim_logistic(200, f)
sim_new <- sim_logistic(1000, f)
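# Sketch (illustration only): a base-R approximation of the kind of data
# sim_logistic() produces. Assumed, not taken from modeldata's source: A and B
# are independent standard normals, and each row's class is drawn with
# P(class = "one") = plogis(linear predictor). sketch_lp() and
# sketch_sim_logistic() are hypothetical helpers.
sketch_lp <- function(A, B) -1 - 4 * A - 2 * B - 0.2 * A^2 + 1 * B^2  # same as `f`
sketch_sim_logistic <- function(n, lp_fun) {
  dat <- data.frame(A = rnorm(n), B = rnorm(n))
  # Inverse logit of the linear predictor gives the probability of "one"
  prob_one <- plogis(lp_fun(dat$A, dat$B))
  dat$class <- factor(ifelse(runif(n) <= prob_one, "one", "two"),
                      levels = c("one", "two"))
  dat
}
set.seed(1)
sketch_sim <- sketch_sim_logistic(200, sketch_lp)
table(sketch_sim$class)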
set.seed(14)
sim_rs <- vfold_cv(sim_tr)
set.seed(14)
sim_nested_rs <- nested_cv(sim_tr, outside = vfold_cv(), inside = vfold_cv())
# Same seed, so the outer resamples match the non-nested ones
all.equal(
  sim_rs$splits[[1]]$in_id,
  sim_nested_rs$splits[[1]]$in_id
)
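# ...although the inner resamples are built from each outer analysis set,
# not from the full training set, so their data differ
identical(
  sim_rs$splits[[1]]$data,
  sim_nested_rs$inner_resamples[[1]]$splits[[1]]$data
)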
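# Sketch (illustration only): a base-R version of the structure nested_cv()
# builds. The inner splits are made from each outer analysis set, never from
# the full training set, which is why the identical() check above fails.
# sketch_vfold() is a hypothetical helper using a plain random fold
# assignment; it is not rsample's code.
sketch_vfold <- function(ids, v = 10) {
  fold <- sample(rep_len(seq_len(v), length(ids)))
  lapply(seq_len(v), function(k) {
    list(analysis = ids[fold != k], assessment = ids[fold == k])
  })
}
set.seed(14)
sketch_outer <- sketch_vfold(seq_len(200), v = 10)
sketch_inner <- lapply(sketch_outer, function(split) sketch_vfold(split$analysis, v = 10))
# TRUE: no inner split ever touches its outer fold's assessment rows
all(vapply(seq_along(sketch_outer), function(k) {
  inner_rows <- unlist(lapply(sketch_inner[[k]], function(s) c(s$analysis, s$assessment)))
  !any(inner_rows %in% sketch_outer[[k]]$assessment)
}, logical(1)))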
# ------------------------------------------------------------------------------
knn_spec <-
  nearest_neighbor(neighbors = tune(), dist_power = tune(),
                   weight_func = "rectangular") %>%
  set_mode("classification")
knn_wflow <-
  workflow() %>%
  add_model(knn_spec) %>%
  add_formula(class ~ A + B)
# ------------------------------------------------------------------------------
# Space-filling design with 15 candidates over both tuning parameters
sfd_size <- 15
knn_prm <- parameters(neighbors(c(2, 50)), dist_power())
vals <- map(knn_prm$object, ~ rep_len(value_seq(.x, sfd_size), length.out = sfd_size))
knn_sfd <-
  get_design(2, sfd_size, type = "max_min_l1") %>%
  update_values(vals) %>%
  setNames(knn_prm$id)
knn_sfd |>
  ggplot(aes(neighbors, dist_power)) +
  geom_point()
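# Sketch (illustration only): a greedy max-min ("maximin") design with L1
# distance, in the spirit of what the name type = "max_min_l1" suggests:
# spread the points so the smallest L1 distance between them is large. This
# greedy candidate-selection method is a stand-in, not how sfd builds its
# designs; sketch_max_min_l1() is a hypothetical helper.
sketch_max_min_l1 <- function(n_points, n_dim = 2, n_cand = 2000) {
  cand <- matrix(runif(n_cand * n_dim), ncol = n_dim)  # candidates in [0, 1]^d
  chosen <- 1L  # start from an arbitrary candidate
  # Smallest L1 distance from each candidate to the points chosen so far
  min_dist <- rowSums(abs(sweep(cand, 2, cand[1, ])))
  while (length(chosen) < n_points) {
    nxt <- which.max(min_dist)  # candidate farthest from the current design
    chosen <- c(chosen, nxt)
    min_dist <- pmin(min_dist, rowSums(abs(sweep(cand, 2, cand[nxt, ]))))
  }
  cand[chosen, , drop = FALSE]
}
set.seed(2)
sketch_unit <- sketch_max_min_l1(15)
# Rescale the unit-square columns to the tuning ranges used above
sketch_design <- data.frame(
  neighbors  = round(2 + sketch_unit[, 1] * (50 - 2)),
  dist_power = 1 + sketch_unit[, 2] * (2 - 1)
)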
# ------------------------------------------------------------------------------
knn_sfd_res <-
  knn_wflow %>%
  tune_grid(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )
autoplot(knn_sfd_res)
show_best(knn_sfd_res, metric = "brier_class", n = 1)
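# Sketch (illustration only): the metric used throughout is the Brier score.
# This is the classic two-class version: the mean squared difference between
# a 0/1 indicator for the first class level and its predicted probability
# (lower is better). sketch_brier_binary() is hypothetical; see yardstick's
# brier_class() documentation for how it handles more than two classes.
sketch_brier_binary <- function(truth, prob_first) {
  truth <- factor(truth)
  y <- as.numeric(truth == levels(truth)[1])
  mean((y - prob_first)^2)
}
sketch_brier_binary(c("one", "two", "one", "two"), c(0.9, 0.2, 0.6, 0.5))  # 0.115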
# ------------------------------------------------------------------------------
knn_sfd_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )
collect_metrics(knn_sfd_nest_res)
# Tally the parameter combinations selected across the outer resamples
final_grid <-
  knn_sfd_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power) %>%
  arrange(desc(n)) |>
  mutate(method = "Grid Search")
final_grid
# ------------------------------------------------------------------------------
# Iterative optimization with simulated annealing
set.seed(97)
knn_sa_res <-
  knn_wflow %>%
  tune_sim_anneal(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm
  )
autoplot(knn_sa_res)
autoplot(knn_sa_res, type = "parameters")
show_best(knn_sa_res, metric = "brier_class", n = 1)
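# Sketch (illustration only): a generic simulated-annealing loop in base R.
# Propose a nearby candidate, always accept improvements, and accept worse
# candidates with probability exp(-increase / temperature) while the
# temperature cools. This is not finetune's algorithm (tune_sim_anneal() has
# its own acceptance rule, cooling, and restarts), and sketch_objective() is a
# made-up stand-in for a resampled Brier score.
sketch_objective <- function(neighbors, dist_power) {
  # Hypothetical loss with its minimum at neighbors = 20, dist_power = 1.5
  ((neighbors - 20) / 48)^2 + (dist_power - 1.5)^2
}
sketch_sim_anneal <- function(iter = 20, temp = 0.1, cool = 0.9) {
  current <- c(neighbors = 2, dist_power = 1)
  current_loss <- sketch_objective(current[["neighbors"]], current[["dist_power"]])
  best <- current
  best_loss <- current_loss
  for (i in seq_len(iter)) {
    # Propose a small random move, clamped to the tuning ranges
    proposal <- c(
      neighbors  = min(50, max(2, current[["neighbors"]] + sample(-5:5, 1))),
      dist_power = min(2, max(1, current[["dist_power"]] + rnorm(1, sd = 0.1)))
    )
    prop_loss <- sketch_objective(proposal[["neighbors"]], proposal[["dist_power"]])
    # Accept improvements; accept worse moves with probability exp(-increase / temp)
    if (prop_loss < current_loss ||
        runif(1) < exp(-(prop_loss - current_loss) / temp)) {
      current <- proposal
      current_loss <- prop_loss
    }
    if (current_loss < best_loss) {
      best <- current
      best_loss <- current_loss
    }
    temp <- temp * cool  # geometric cooling schedule
  }
  list(best = best, best_loss = best_loss)
}
set.seed(97)
sketch_sa <- sketch_sim_anneal()
sketch_sa$best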
# ------------------------------------------------------------------------------
set.seed(97)
knn_sa_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    fn = "tune_sim_anneal",
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm,
    control = control_sim_anneal(verbose_iter = FALSE)
  )
collect_metrics(knn_sa_nest_res)
final_sa <-
  knn_sa_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power) %>%
  arrange(desc(n)) |>
  mutate(method = "SA")
final_sa
# Compare the parameters selected by grid search and by simulated annealing
bind_rows(final_grid, final_sa) |>
  ggplot(aes(neighbors, dist_power, col = method, pch = method)) +
  geom_point() +
  lims(x = c(2, 50), y = c(1, 2)) +
  labs(x = neighbors()$label, y = dist_power()$label) +
  theme(legend.position = "top")
@DobraVila

Is it possible to use the tune_nested() function, and to which package does it belong?

@topepo
Author

topepo commented Jul 31, 2024

If you want something temporary, then sure. It is in a branch of finetune (https://github.com/tidymodels/finetune/tree/nested).

The problem is that it is a prototype and the API and syntax might change in the near future.

@DobraVila

Hi Max, thanks for your prompt response! I am currently working on a study using tidymodels and need to implement nested cross-validation to prevent data leakage. I am struggling to tune my model with nested_cv(). I have reviewed the example provided on the website and a few other scripts I found on GitHub, but with no luck. It seems like the tune_nested() function might not be the ideal solution for my case. If you don't mind, could you point me to any available examples or resources on how to tune hyperparameters for a classification model using nested_cv()? All the examples I found were for regression models. Thanks a lot!

@topepo
Author

topepo commented Aug 1, 2024

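Nested resampling does not optimize your model. Once you have optimized your model using functions like tune_grid() or tune_bayes(), nested resampling gives you a better performance estimate than those functions can produce. The resampled metric of the best candidate from a tune_*() function is optimistically biased, because the same resamples were used both to choose that candidate and to score it; nested resampling scores each choice on outer assessment data that the tuning never saw.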

In a way, it helps validate how you have already optimized your model.
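To make the distinction concrete, here is a base-R sketch of nested cross-validation for a two-class problem, with no tidymodels code at all. It is not how tune_nested() is implemented. The simulated data only loosely mirror the gist's equation (independent standard-normal predictors are an assumption), and the helpers fold_ids(), brier(), knn_prob() and inner_tune(), plus the small (k, p) candidate grid, are hypothetical. The inner loop does all the optimizing; the outer loop only scores the result of that optimization.

```r
# Simulated two-class data (assumed setup, loosely mirroring the gist)
set.seed(943)
n <- 200
dat <- data.frame(A = rnorm(n), B = rnorm(n))
lp <- -1 - 4 * dat$A - 2 * dat$B - 0.2 * dat$A^2 + dat$B^2
dat$class <- factor(ifelse(runif(n) <= plogis(lp), "one", "two"),
                    levels = c("one", "two"))
x <- as.matrix(dat[, c("A", "B")])

# Hypothetical helpers (not tidymodels functions)
fold_ids <- function(n_rows, v) sample(rep_len(seq_len(v), n_rows))
brier <- function(truth, prob_first) {
  mean((as.numeric(truth == levels(truth)[1]) - prob_first)^2)
}
# Share of the k nearest neighbors (Minkowski distance, power p) in the first class
knn_prob <- function(train_x, train_y, test_x, k, p) {
  first <- levels(train_y)[1]
  apply(test_x, 1, function(pt) {
    d <- rowSums(abs(sweep(train_x, 2, pt))^p)^(1 / p)
    mean(train_y[order(d)[seq_len(k)]] == first)
  })
}

candidates <- expand.grid(k = c(3, 11, 25), p = c(1, 2))

# Inner loop: tune on `rows` only, using the same inner folds for every candidate
inner_tune <- function(rows, v = 5) {
  folds <- fold_ids(length(rows), v)
  scores <- apply(candidates, 1, function(cand) {
    mean(sapply(seq_len(v), function(f) {
      tr <- rows[folds != f]
      te <- rows[folds == f]
      prob <- knn_prob(x[tr, , drop = FALSE], dat$class[tr],
                       x[te, , drop = FALSE], cand[["k"]], cand[["p"]])
      brier(dat$class[te], prob)
    }))
  })
  candidates[which.min(scores), ]
}

# Outer loop: each outer fold tunes on its analysis rows, then is scored once
outer_folds <- fold_ids(n, 5)
outer_res <- do.call(rbind, lapply(seq_len(5), function(f) {
  tr <- which(outer_folds != f)
  te <- which(outer_folds == f)
  best <- inner_tune(tr)  # tuning never sees the rows in `te`
  prob <- knn_prob(x[tr, , drop = FALSE], dat$class[tr],
                   x[te, , drop = FALSE], best$k, best$p)
  data.frame(best, brier = brier(dat$class[te], prob))
}))

mean(outer_res$brier)                    # nested estimate of the tuning procedure
table(k = outer_res$k, p = outer_res$p)  # how often each candidate was selected
```

The final table is the same kind of summary as counting the .selected parameters in the gist: a wide spread of selections across outer folds suggests the tuning result itself is unstable.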

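The code is designed to mirror how you use the tune_*() functions: you pass the same metrics and model arguments you would give those functions (for iterative search, the fn argument names the tuning function, e.g. fn = "tune_sim_anneal" in the gist above).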

As I said, it is a prototype; the documentation will be written once we have the code and user interface worked out.
