@topepo
Created June 6, 2024 13:30
nested resampling in tidymodels
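# Prototype demo of nested resampling in tidymodels, using tune_nested() from
# the `nested` branch of finetune (installed via the pak call below).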
# pak::pak(c("tidymodels/finetune@nested"), ask = FALSE)
library(tidymodels)
library(finetune)
library(rlang)
library(sfd)
library(doMC)
# ------------------------------------------------------------------------------
tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
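# Register a parallel backend (10 workers) for tuning via foreach/doMC.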
registerDoMC(cores = 10)
# ------------------------------------------------------------------------------
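# Simulate two-class data where `f` is the true linear predictor.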
f <- expr(-1 - 4 * A - 2 * B - 0.2 * A^2 + 1 * B^2)
x_seq <- seq(-4, 4, length.out = 100)
grid <-
  crossing(A = seq(-3, 3, length.out = 100), B = x_seq) %>%
  mutate(
    lp = rlang::eval_tidy(f, data = .)
  )
set.seed(943)
sim_tr <- sim_logistic(200, f)
sim_new <- sim_logistic(1000, f)
set.seed(14)
sim_rs <- vfold_cv(sim_tr)
set.seed(14)
sim_nested_rs <- nested_cv(sim_tr, outside = vfold_cv(), inside = vfold_cv())
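# nested_cv() adds an `inner_resamples` column: the analysis set of each outer
# fold is resampled again so that tuning can happen inside every outer fold.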
# Same outer resamples
all.equal(
  sim_rs$splits[[1]]$in_id,
  sim_nested_rs$splits[[1]]$in_id
)
# although...
identical(
  sim_rs$splits[[1]]$data,
  sim_nested_rs$inner_resamples[[1]]$splits[[1]]$data
)
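# ...this is FALSE: each inner resample is built from the outer analysis set
# only, so the underlying data frames differ.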
# ------------------------------------------------------------------------------
knn_spec <-
  nearest_neighbor(neighbors = tune(), dist_power = tune(),
                   weight_func = "rectangular") %>%
  set_mode("classification")

knn_wflow <-
  workflow() %>%
  add_model(knn_spec) %>%
  add_formula(class ~ A + B)
# ------------------------------------------------------------------------------
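# Build a 15-point max-min (L1) space-filling design over the two tuning
# parameters with the sfd package.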
sfd_size <- 15
knn_prm <- parameters(neighbors(c(2, 50)), dist_power())
vals <- map(knn_prm$object, ~ rep_len(value_seq(.x, sfd_size), length.out = sfd_size))

knn_sfd <-
  get_design(2, sfd_size, type = "max_min_l1") %>%
  update_values(vals) %>%
  setNames(knn_prm$id)

knn_sfd %>%
  ggplot(aes(neighbors, dist_power)) +
  geom_point()
# ------------------------------------------------------------------------------
knn_sfd_res <-
  knn_wflow %>%
  tune_grid(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )
autoplot(knn_sfd_res)
show_best(knn_sfd_res, metric = "brier_class", n = 1)
# ------------------------------------------------------------------------------
knn_sfd_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    metrics = metric_set(brier_class),
    grid = knn_sfd
  )
collect_metrics(knn_sfd_nest_res)
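# Tally which candidate each outer resample selected (the `.selected` column)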
final_grid <-
  knn_sfd_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power) %>%
  arrange(desc(n)) %>%
  mutate(method = "Grid Search")
final_grid
# ------------------------------------------------------------------------------
# Iterative optimization via simulated annealing
set.seed(97)
knn_sa_res <-
  knn_wflow %>%
  tune_sim_anneal(
    resamples = sim_rs,
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm
  )
autoplot(knn_sa_res)
autoplot(knn_sa_res, type = "parameters")
show_best(knn_sa_res, metric = "brier_class", n = 1)
# ------------------------------------------------------------------------------
set.seed(97)
knn_sa_nest_res <-
  knn_wflow %>%
  tune_nested(
    resamples = sim_nested_rs,
    fn = "tune_sim_anneal",
    metrics = metric_set(brier_class),
    initial = 4,
    iter = 20,
    param_info = knn_prm,
    control = control_sim_anneal(verbose_iter = FALSE)
  )
collect_metrics(knn_sa_nest_res)
final_sa <-
  knn_sa_nest_res %>%
  select(.selected) %>%
  unnest(.selected) %>%
  count(neighbors, dist_power) %>%
  arrange(desc(n)) %>%
  mutate(method = "SA")
final_sa
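# Compare the candidates selected across outer resamples by grid search
# versus simulated annealing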
bind_rows(final_grid, final_sa) %>%
  ggplot(aes(neighbors, dist_power, col = method, pch = method)) +
  geom_point() +
  lims(x = c(2, 50), y = 1:2) +
  labs(x = neighbors()$label, y = dist_power()$label) +
  theme(legend.position = "top")
topepo commented Aug 1, 2024

Nested resampling does not optimize your model. Once you have optimized your model using functions like tune_grid() or tune_bayes(), nested resampling gives you a better (less optimistically biased) performance estimate than those functions can produce on their own.

In a way, it helps validate how you have already optimized your model.
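For the gist above, the comparison is just two calls (a minimal sketch reusing the objects defined in the script and the prototype methods it already demonstrates):

```r
# Estimate from tuning itself: the winner is scored on the same resamples
# that were used to pick it, so it tends to be optimistic.
show_best(knn_sfd_res, metric = "brier_class", n = 1)

# Nested estimate: inner resamples pick a candidate for each outer fold, and
# the outer assessment sets, never seen during tuning, score it.
collect_metrics(knn_sfd_nest_res)
```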

The code is designed to mirror how you use the tune_*() functions: you pass the same metrics and model arguments as you would to those functions.
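For example, the grid search calls in the script differ only in the resampling object (a sketch, assuming the prototype defaults to grid search when `fn` is not given, as the script above suggests):

```r
# Standard tuning over 10-fold CV:
knn_wflow %>%
  tune_grid(resamples = sim_rs, metrics = metric_set(brier_class), grid = knn_sfd)

# Nested version: same workflow, metrics, and grid; swap in nested resamples.
knn_wflow %>%
  tune_nested(resamples = sim_nested_rs, metrics = metric_set(brier_class), grid = knn_sfd)
```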

As I said, it is a prototype; the documentation will be written once we have the code and user interface worked out.
