-
-
Save topepo/5b57761f98d6ef1afa0b262fc870a7d7 to your computer and use it in GitHub Desktop.
# Setup: attach the packages used below and configure the session.
# The nested-resampling prototype lives in a branch of finetune; install it
# once with:
# pak::pak(c("tidymodels/finetune@nested"), ask = FALSE)
library(tidymodels)
library(finetune)
library(rlang)
library(sfd)
library(doMC)

# ------------------------------------------------------------------------------

# Resolve function-name conflicts in favour of the tidymodels packages.
tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)

# Register a doMC parallel backend; adjust `cores` to the machine at hand.
registerDoMC(cores = 10)
# ------------------------------------------------------------------------------
# Simulated two-class data and the resampling schemes

# Expression for the linear predictor, written in terms of predictors A and B.
f <- expr(-1 - 4 * A - 2 * B - 0.2 * A^2 + 1 * B^2)

# Regular grid over the predictor space with the linear predictor evaluated at
# every grid point.
x_seq <- seq(-4, 4, length.out = 100)
grid <-
  crossing(A = seq(-3, 3, length.out = 100), B = x_seq) %>%
  mutate(
    lp = rlang::eval_tidy(f, data = .)
  )

set.seed(943)
sim_tr <- sim_logistic(200, f)
sim_new <- sim_logistic(1000, f)

# The same seed is set before both resampling calls so that the outer folds of
# the nested scheme can be compared with the ordinary V-fold splits.
set.seed(14)
sim_rs <- vfold_cv(sim_tr)

set.seed(14)
sim_nested_rs <- nested_cv(sim_tr, outside = vfold_cv(), inside = vfold_cv())

# Same outer resamples
all.equal(
  sim_rs$splits[[1]]$in_id,
  sim_nested_rs$splits[[1]]$in_id
)

# although...
# NOTE(review): presumably the inner resamples are built from the outer
# analysis set rather than the full training set, so this is expected to
# differ -- run to confirm.
identical(
  sim_rs$splits[[1]]$data,
  sim_nested_rs$inner_resamples[[1]]$splits[[1]]$data
)
# ------------------------------------------------------------------------------
# Model: nearest-neighbour classifier with the number of neighbours and the
# distance power marked for tuning.

knn_spec <-
  nearest_neighbor(
    neighbors = tune(),
    dist_power = tune(),
    weight_func = "rectangular"
  ) %>%
  set_mode("classification")

# Workflow pairing the model with the formula `class ~ A + B`.
knn_wflow <-
  workflow() %>%
  add_model(knn_spec) %>%
  add_formula(class ~ A + B)
# ------------------------------------------------------------------------------
# Candidate grid: a space-filling design over the two tuning parameters.

sfd_size <- 15
knn_prm <- parameters(neighbors(c(2, 50)), dist_power())

# One vector of candidate values per parameter; `rep_len()` guarantees each
# vector has exactly `sfd_size` elements even if `value_seq()` returns fewer.
vals <- map(
  knn_prm$object,
  \(prm) rep_len(value_seq(prm, sfd_size), length.out = sfd_size)
)

# Map the design onto the parameter values and name the columns by parameter id.
knn_sfd <-
  get_design(2, sfd_size, type = "max_min_l1") %>%
  update_values(vals) %>%
  setNames(knn_prm$id)

knn_sfd %>%
  ggplot(aes(neighbors, dist_power)) +
  geom_point()
# ------------------------------------------------------------------------------
# Grid search with ordinary (non-nested) resampling, scored by the Brier score.

knn_sfd_res <-
  knn_wflow %>%
  tune_grid(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )

autoplot(knn_sfd_res)
show_best(knn_sfd_res, metric = "brier_class", n = 1)
# ------------------------------------------------------------------------------
# The same grid search wrapped in nested resampling (prototype `tune_nested()`).

knn_sfd_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )

collect_metrics(knn_sfd_nest_res)

# Tally how often each parameter combination appears in the `.selected`
# column, most frequent first.
final_grid <-
  knn_sfd_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power) %>%
  arrange(desc(n)) %>%
  mutate(method = "Grid Search")

final_grid
# ------------------------------------------------------------------------------
# Iterative optimization: simulated annealing with ordinary resampling.

set.seed(97)
knn_sa_res <-
  knn_wflow %>%
  tune_sim_anneal(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm
  )

autoplot(knn_sa_res)
autoplot(knn_sa_res, type = "parameters")
show_best(knn_sa_res, metric = "brier_class", n = 1)
# ------------------------------------------------------------------------------
# Simulated annealing wrapped in nested resampling; the tuning function is
# passed by name through `fn`, with the same arguments as the direct call.

set.seed(97)
knn_sa_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    fn = "tune_sim_anneal",
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm,
    control = control_sim_anneal(verbose_iter = FALSE)
  )

collect_metrics(knn_sa_nest_res)

# Tally how often each parameter combination appears in the `.selected`
# column, most frequent first.
final_sa <-
  knn_sa_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power) %>%
  arrange(desc(n)) %>%
  mutate(method = "SA")

final_sa

# Compare the selections of the two search methods, with axes fixed to the
# parameter ranges.
bind_rows(final_grid, final_sa) %>%
  ggplot(aes(neighbors, dist_power, colour = method, shape = method)) +
  geom_point() +
  lims(x = c(2, 50), y = c(1, 2)) +
  labs(x = neighbors()$label, y = dist_power()$label) +
  theme(legend.position = "top")
If you want something temporary, then sure. It is in a branch of finetune (https://github.com/tidymodels/finetune/tree/nested).
The problem is that it is a prototype and the API and syntax might change in the near future.
Hi Max, thanks for your prompt response! I am currently working on a study using tidymodels and need to implement nested cross-validation to prevent data leakage. I am struggling with tuning my model using `nested_cv`. I have reviewed the example provided on the website and a few other scripts I found on GitHub, but with no luck. It seems like using the `tune_nested` function might not be the ideal solution for my case. If you don't mind, could you point me to any available examples or resources on how to tune hyperparameters for a classification model using `nested_cv`? All the examples I found were for regression models. Thanks a lot!
Nested resampling does not optimize your model. Once you have optimized your model using functions like `tune_grid()` or `tune_bayes()`, nested resampling gives you a better performance estimate than those functions can produce.
In a way, it helps validate how you have already optimized your model.
The code is designed to mirror how you used the `tune_*()` functions. You use the same `metrics` and model arguments as those functions.
As I said, it is a prototype; the documentation will be written once we have the code and user interface worked out.
Is it possible to use the `tune_nested()` function, and to which package does it belong?