Skip to content

Instantly share code, notes, and snippets.

@topepo
Created October 30, 2019 13:16
Show Gist options
  • Save topepo/f734155e4e402ddd3ad3fb10b133dfe2 to your computer and use it in GitHub Desktop.
Save topepo/f734155e4e402ddd3ad3fb10b133dfe2 to your computer and use it in GitHub Desktop.
Prototype code to update parsnip model specifications and recipes with their final tuning parameter values.
# Update a parsnip model specification with finalized tuning parameter values.
#
# `x` is a parsnip `model_spec` whose arguments were marked with `tune()`.
# `param` is a named list, or a one-row data frame / tibble (for example the
# output of `select_best()`), whose names are parameter ids. Entries whose
# names are not ids of `x`'s tunable parameters are dropped. When an id
# differs from the argument name (e.g. `cost_complexity = tune("cp")`), the
# value is renamed back to the argument name before calling `update()`.
#
# Returns the updated model specification. Errors if `x` is not a
# `model_spec` or if `param` is a data frame without exactly one row.
finalize_model <- function(x, param) {
  if (!inherits(x, "model_spec")) {
    stop("`x` should be a parsnip model specification.", call. = FALSE)
  }
  if (is.data.frame(param)) {
    # More than one row would splice length > 1 vectors into update().
    if (nrow(param) != 1L) {
      stop("`param` should have a single row.", call. = FALSE)
    }
    param <- as.list(param)
  }
  pset <- dials::parameters(x)
  param <- param[names(param) %in% pset$id]
  # Translate parameter ids into the argument names update() expects; every
  # remaining name is guaranteed to be found because of the filter above.
  names(param) <- pset$name[match(names(param), pset$id)]
  rlang::exec(update, object = x, !!!param)
}
# Update a recipe with finalized tuning parameter values.
#
# `x` is a recipe whose step arguments were marked with `tune()`. `param` is a
# named list, or a one-row data frame / tibble (for example the output of
# `select_best()`), whose names are parameter ids. Only ids whose `source` is
# "recipe" are used; anything else (e.g. model parameters) is ignored.
#
# Returns the updated recipe. Errors if `x` is not a recipe or if `param` is a
# data frame without exactly one row.
finalize_recipe <- function(x, param) {
  if (!inherits(x, "recipe")) {
    stop("`x` should be a recipe.", call. = FALSE)
  }
  if (is.data.frame(param)) {
    # More than one row would splice length > 1 vectors into update().
    if (nrow(param) != 1L) {
      stop("`param` should have a single row.", call. = FALSE)
    }
    param <- as.list(param)
  }
  pset <-
    dials::parameters(x) %>%
    dplyr::filter(id %in% names(param) & source == "recipe")
  # Put the values in the same order as the rows of `pset`.
  param <- param[pset$id]
  # Group parameter rows AND their values by step with the same grouping
  # factor, so group i of each list refers to the same step. (Indexing the
  # values by position would mis-pair them whenever more than one recipe
  # parameter is tuned.)
  pset_by_step <- split(pset, pset$component_id)
  param_by_step <- split(param, pset$component_id)
  for (i in seq_along(pset_by_step)) {
    x <- complete_steps(param_by_step[[i]], pset_by_step[[i]], x)
  }
  x
}
# Update the single recipe step that owns the parameters described by `pset`.
#
# `param` holds the finalized values (a list or a named atomic vector), one
# per row of `pset` and in the same order. `pset` holds the parameter rows for
# one step: its `component_id` column identifies the step (matched against
# each step's `id`) and its `name` column gives the step argument to set.
#
# Returns `object` with that step replaced by the result of `update()`.
# Errors if no step in `object` has the requested id.
complete_steps <- function(param, pset, object) {
  step_ids <- vapply(object$steps, function(step) step$id, character(1))
  # All rows of `pset` belong to the same step, so look up one id. Comparing
  # the whole column against `step_ids` would recycle the shorter vector.
  target_id <- pset$component_id[1]
  step_index <- match(target_id, step_ids)
  if (is.na(step_index)) {
    stop("No step with id '", target_id, "' was found in the recipe.",
         call. = FALSE)
  }
  # Parameter ids (e.g. "# components") may differ from the argument names.
  names(param) <- pset$name
  object$steps[[step_index]] <-
    do.call(update, c(list(object = object$steps[[step_index]]), param))
  object
}
# ------------------------------------------------------------------------------
library(tidymodels)
library(tune)
library(mlbench)
data(PimaIndiansDiabetes)

# ------------------------------------------------------------------------------

set.seed(151)
pima_rs <- vfold_cv(PimaIndiansDiabetes, repeats = 1)

# The number of PCA components is tuned with a custom id.
pca_rec <-
  recipe(diabetes ~ ., data = PimaIndiansDiabetes) %>%
  step_pca(all_predictors(), num_comp = tune("# components"))

# `cost_complexity` is tuned under the id "cp"; `min_n` keeps its own name.
tree_mod <-
  decision_tree(cost_complexity = tune("cp"), min_n = tune()) %>%
  set_mode("classification") %>%
  set_engine("rpart")

roc_vals <- metric_set(roc_auc)

set.seed(3625)
pima_res <-
  tune_grid(pca_rec,
            tree_mod,
            resamples = pima_rs,
            metrics = roc_vals)

# Name the metric explicitly instead of relying on select_best()'s default.
final_param <- select_best(pima_res, metric = "roc_auc")
final_rec <- finalize_recipe(pca_rec, final_param)
final_mod <- finalize_model(tree_mod, final_param)
@topepo
Copy link
Author

topepo commented Sep 22, 2020

This is already implemented in the tune package (`tune::finalize_model()` and `tune::finalize_recipe()`); use those functions instead.

@Steviey
Copy link

Steviey commented Sep 22, 2020

R 4.x, Ubuntu 16.x LTS

My fault, thank you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment