Created
October 30, 2019 13:16
-
-
Save topepo/f734155e4e402ddd3ad3fb10b133dfe2 to your computer and use it in GitHub Desktop.
Prototype code to update parsnip model and recipe objects with final tuning parameter values.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Finalize a parsnip model specification with chosen parameter values
#'
#' Replaces the tuning placeholders in `x` with concrete values, e.g. the
#' single-row result of `select_best()`.
#'
#' @param x A parsnip model specification (must inherit from `model_spec`).
#' @param param A named list, or a one-row data frame/tibble, of values keyed
#'   by parameter id. Entries whose names are not parameter ids of `x` are
#'   ignored (so extra columns such as metric summaries are harmless).
#' @return The result of calling `update()` on `x` with the matched values.
finalize_model <- function(x, param) {
  if (!inherits(x, "model_spec")) {
    stop("`x` should be a parsnip model specification.", call. = FALSE)
  }
  if (is.data.frame(param)) {
    # A multi-row grid would splice length > 1 vectors into `update()`.
    if (nrow(param) != 1L) {
      stop("`param` should have exactly one row of parameter values.",
           call. = FALSE)
    }
    param <- as.list(param)
  }
  pset <- dials::parameters(x)
  param <- param[names(param) %in% pset$id]
  # A parameter tagged with a custom id (e.g. `tune("cp")`) must be passed to
  # `update()` under its argument name (e.g. `cost_complexity`); for untagged
  # parameters the id and the name are identical, so this is a no-op.
  names(param) <- pset$name[match(names(param), pset$id)]
  rlang::exec(update, object = x, !!!param)
}
#' Finalize a recipe with chosen tuning parameter values
#'
#' Each recipe step that has tuned parameters is updated once, with all of
#' that step's values, via `complete_steps()`.
#'
#' @param x A recipe (must inherit from `recipe`).
#' @param param A named list, or a one-row data frame/tibble, of values keyed
#'   by parameter id. Entries that are not recipe parameter ids of `x` are
#'   ignored.
#' @return The recipe `x` with the matching steps updated.
finalize_recipe <- function(x, param) {
  if (!inherits(x, "recipe")) {
    stop("`x` should be a recipe.", call. = FALSE)
  }
  if (is.data.frame(param)) {
    # A multi-row grid would splice length > 1 vectors into `update()`.
    if (nrow(param) != 1L) {
      stop("`param` should have exactly one row of parameter values.",
           call. = FALSE)
    }
    param <- as.list(param)
  }
  pset <- dplyr::filter(
    dials::parameters(x),
    id %in% names(param) & source == "recipe"
  )
  # Group the parameters by step and look the values up by id. Pairing by
  # position would misalign values whenever the split order differs from the
  # id order, or when one step has several tuned parameters.
  for (step_pset in split(pset, pset$component_id)) {
    x <- complete_steps(param[step_pset$id], step_pset, x)
  }
  x
}
#' Update one recipe step with its final parameter values
#'
#' @param param Values for this step's parameters, in the same order as the
#'   rows of `pset`.
#' @param pset Parameter-set rows (with `name` and `component_id` columns)
#'   that all belong to a single step.
#' @param object The recipe containing that step.
#' @return `object` with the matching entry of `object$steps` replaced by the
#'   result of `update()`. If `pset` has no rows, `object` is returned as is.
complete_steps <- function(param, pset, object) {
  step_id <- unique(pset$component_id)
  if (length(step_id) == 0L) {
    return(object)
  }
  if (length(step_id) > 1L) {
    stop("`pset` should describe parameters from a single step.",
         call. = FALSE)
  }
  # Locate the step by id; `match()` avoids the vector recycling that `==`
  # would do when the step has more than one tuned parameter.
  step_ids <- purrr::map_chr(object$steps, ~ .x$id)
  step_index <- match(step_id, step_ids)
  if (is.na(step_index)) {
    stop("No step with id '", step_id, "' was found in the recipe.",
         call. = FALSE)
  }
  # Step arguments are addressed by `name`, not by the (possibly custom) id.
  names(param) <- pset$name
  object$steps[[step_index]] <-
    rlang::exec(update, object = object$steps[[step_index]], !!!param)
  object
}
# ------------------------------------------------------------------------------
# Example: tune a PCA recipe and a decision tree on the Pima data, then
# finalize both objects with the best parameter combination.
library(tidymodels)
library(tune)
library(mlbench)
data(PimaIndiansDiabetes)
# ------------------------------------------------------------------------------
set.seed(151)
pima_rs <- vfold_cv(PimaIndiansDiabetes, repeats = 1)

pca_rec <-
  recipe(diabetes ~ ., data = PimaIndiansDiabetes) %>%
  step_pca(all_predictors(), num_comp = tune("# components"))

tree_mod <-
  decision_tree(cost_complexity = tune("cp"), min_n = tune()) %>%
  set_mode("classification") %>%
  set_engine("rpart")

roc_vals <- metric_set(roc_auc)

set.seed(3625)
# NOTE(review): the recipe-first argument order matches the tune API this
# prototype was written against; newer tune releases expect the model first.
pima_res <-
  tune_grid(pca_rec,
            tree_mod,
            resamples = pima_rs,
            metrics = roc_vals)

# Name the metric explicitly so the selection does not rely on a default.
final_param <- select_best(pima_res, metric = "roc_auc")
final_rec <- finalize_recipe(pca_rec, final_param)
final_mod <- finalize_model(tree_mod, final_param)
I'm writing those methods now. I might make them into `update()` methods, but that would require some minor changes to parsnip.
Any progress here? It seems to me that `select_best()` misses the step-tuning data. Am I wrong?
This is already implemented in tune; look at those functions instead.
R 4.x, Ubuntu 16.x LTS
My fault, thank you.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Or it could be built into a workflow, as you suggest in the vignette — then the final solution would be even more elegant. It could be something along these lines: