Created
October 30, 2019 13:16
-
-
Save topepo/f734155e4e402ddd3ad3fb10b133dfe2 to your computer and use it in GitHub Desktop.
Prototype code to update parsnip and recipe objects with final parameter values.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Finalize a parsnip model specification with chosen parameter values
#'
#' Looks up the tunable parameters of `x`, keeps the entries of `param` whose
#' names match a parameter id, and passes them to `update()` under the
#' corresponding argument names. Entries of `param` that do not match any
#' parameter id are ignored.
#'
#' @param x A parsnip model specification (must inherit from `"model_spec"`).
#' @param param A named list, or a tibble with a single row, of parameter
#'   values keyed by parameter id.
#' @return The value of `update()` applied to `x` with the matched values.
finalize_model <- function(x, param) {
  if (!inherits(x, "model_spec")) {
    stop("`x` should be a parsnip model specification.", call. = FALSE)
  }

  if (tibble::is_tibble(param)) {
    # Every column becomes one argument value, so several rows would hand
    # whole vectors of candidate values to `update()`.
    if (nrow(param) > 1) {
      stop(
        "`param` should contain a single set of parameter values.",
        call. = FALSE
      )
    }
    param <- as.list(param)
  }

  # Namespaced so the function does not depend on dials being attached.
  pset <- dials::parameters(x)
  param <- param[names(param) %in% pset$id]

  # A parameter can carry an id that differs from its argument name (for
  # example `cost_complexity = tune("cp")`). `param` is keyed by id, whereas
  # `update()` is called with argument names, so translate id -> name in one
  # pass. Doing it in one pass avoids re-renaming an entry that was already
  # renamed to a name that happens to equal another parameter's id.
  names(param) <- pset$name[match(names(param), pset$id)]

  rlang::exec(update, object = x, !!!param)
}
#' Finalize a recipe with chosen parameter values
#'
#' Finds the tunable parameters of `x` whose source is `"recipe"` and whose id
#' appears in `names(param)`, groups them by the step they belong to, and
#' updates each step through `complete_steps()`.
#'
#' @param x A recipe (must inherit from `"recipe"`).
#' @param param A named list, or a tibble with a single row, of parameter
#'   values keyed by parameter id. Entries that do not correspond to a recipe
#'   parameter are ignored.
#' @return The recipe `x` with the matching steps updated.
finalize_recipe <- function(x, param) {
  if (!inherits(x, "recipe")) {
    stop("`x` should be a recipe.", call. = FALSE)
  }

  if (tibble::is_tibble(param)) {
    # Every column becomes one argument value, so several rows would hand
    # whole vectors of candidate values to the step update.
    if (nrow(param) > 1) {
      stop(
        "`param` should contain a single set of parameter values.",
        call. = FALSE
      )
    }
    param <- as.list(param)
  }

  pset <- dials::parameters(x)
  pset <- dplyr::filter(pset, id %in% names(param) & source == "recipe")

  # `split()` regroups (and reorders) the rows by step, so the values must be
  # looked up by parameter id for each group. Indexing `param` by the group's
  # position would pair a step with the wrong value as soon as one step has
  # more than one tuned argument or the steps are not already in sorted order.
  pset_by_step <- split(pset, pset$component_id)
  for (step_pset in pset_by_step) {
    x <- complete_steps(param[step_pset$id], step_pset, x)
  }
  x
}
#' Update one recipe step with final parameter values
#'
#' @param param The value(s) for the step's tuned arguments: a single value or
#'   a list with one element per row of `pset`, in the same order.
#' @param pset The rows of the parameter set that belong to one step. The
#'   `component_id` column identifies the step and the `name` column gives the
#'   argument names to update.
#' @param object An object with a `steps` list whose elements each have an
#'   `id` element (here, a recipe).
#' @return `object` with the matching step replaced by its updated version.
complete_steps <- function(param, pset, object) {
  # All rows describe the same step, so compare against the single step id
  # rather than the whole column (which would be recycled against the steps).
  component <- unique(pset$component_id)
  if (length(component) != 1) {
    stop("`pset` should describe the parameters of a single step.",
         call. = FALSE)
  }

  # Find the corresponding step in the recipe.
  step_ids <- vapply(object$steps, function(step) step$id, character(1))
  step_index <- which(step_ids == component)
  if (length(step_index) != 1) {
    stop("Could not find exactly one step with id '", component,
         "' in the recipe.", call. = FALSE)
  }

  # The parameter id may differ from the argument name, so the values are
  # renamed to the argument names before being spliced into `update()`.
  param <- as.list(param)
  names(param) <- pset$name

  object$steps[[step_index]] <-
    rlang::exec(update, object = object$steps[[step_index]], !!!param)
  object
}
# ------------------------------------------------------------------------------
# Example: tune a recipe and a model, then plug the selected values back in.

library(tidymodels)
library(tune)
library(mlbench)

data(PimaIndiansDiabetes)

# ------------------------------------------------------------------------------

# Resamples used for tuning.
set.seed(151)
pima_rs <- vfold_cv(PimaIndiansDiabetes, repeats = 1)

# Recipe with one tunable argument; its id differs from the argument name.
pca_rec <- recipe(diabetes ~ ., data = PimaIndiansDiabetes) %>%
  step_pca(all_predictors(), num_comp = tune("# components"))

# Model with two tunable arguments; only `cost_complexity` gets a custom id.
tree_mod <- decision_tree(cost_complexity = tune("cp"), min_n = tune()) %>%
  set_mode("classification") %>%
  set_engine("rpart")

roc_vals <- metric_set(roc_auc)

set.seed(3625)
pima_res <- tune_grid(
  pca_rec,
  tree_mod,
  resamples = pima_rs,
  metrics = roc_vals
)

# Take the selected parameter values and finalize both objects with them.
final_param <- select_best(pima_res)

final_rec <- finalize_recipe(pca_rec, final_param)
final_mod <- finalize_model(tree_mod, final_param)
R 4.x, Ubuntu 16.x LTS
My fault, thank you.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is already implemented in tune; look at those functions instead.