Created
August 4, 2019 21:06
-
-
Save topepo/518bd25071e25807a32e07d7eb1bb30a to your computer and use it in GitHub Desktop.
Alternative infrastructure for finding and coordinating varying parameters in recipes and models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A prototype mechanism to collect the tunable parameters and their sources. It
# also allows parameters with the same name to be tuned. For example, you might
# model some variables with a spline but want to allow for different degrees
# of freedom by adding a different `step_ns()` for each variable. This interface
# allows the user to add an annotation for the parameter so that we can tell the
# different parameters apart.
# Limitations:
# - Currently, only one varying value is allowed per argument. For example, if
#   some argument had a length two vector as its format, you cannot do something
#   like `c(vary(), vary())`.
# - The last bit of code that generates the actual tuning grid assumes that:
#   * The parameter generating functions are in dials.
#   * All parameters have a generating function.
# - We would need to revise the code that updates recipes and models with new
#   parameter values so they can use the parameter annotations to find the
#   right argument to change.
# ------------------------------------------------------------------------------ | |
# devtools::install_github("tidymodels/dials") | |
library(tidymodels) | |
library(rlang) | |
# ------------------------------------------------------------------------------ | |
# I used `vary()` instead of `varying()` here to avoid conflicting function names | |
# We can go back to `varying()` if this is the way to go | |
#' A placeholder function for argument values that are to be tuned.
#'
#' `vary()` is used when a parameter will be specified at a later date.
#'
#' @param id A single, non-missing character value that can be used to
#'   differentiate parameters that are used in multiple places but have the
#'   same name, or to attach a note to the parameter. The default (`""`)
#'   means that no identifier is attached.
#' @return A call object that echoes the user input: `vary()` when `id` is
#'   the empty string and `vary("<id>")` otherwise.
#' @examples
#' vary()
#' class(vary())
#' vary("your name here")
#'
#' @export
vary <- function(id = "") {
  # `NA_character_` passes `is.character()` and has length one, so it must be
  # rejected explicitly; otherwise the comparison with "" below fails with an
  # unrelated "missing value where TRUE/FALSE needed" error.
  if (!is.character(id) || length(id) != 1L || is.na(id)) {
    stop("The `id` should be a single character string.", call. = FALSE)
  }

  # No identifier: echo a bare `vary()` call.
  if (id == "") {
    return(call("vary"))
  }

  # The identifier is stored as the first (unnamed) argument of the call.
  call("vary", id)
}
# S3 generic: the method is selected from the class of `object`; any further
# arguments are passed on to that method.
#' @export
vary_args <- function(object, ...) UseMethod("vary_args")
#' Determine varying arguments
#'
#' `vary_args()` takes a model specification or a recipe and returns a tibble
#' of information on all possible varying arguments and whether or not they
#' are actually varying.
#'
#' The `source` column is determined differently depending on whether a
#' `model_spec` or a `recipe` is used: for a `model_spec` it contains the
#' top-level class of the specification, for a recipe step it contains the
#' `id` of the step.
#'
#' The `id` field has any identifier that was passed to `vary()` (e.g.
#' `vary("some note")`). If no additional detail was used in that function,
#' the `id` field reverts to the name of the parameter.
#'
#' @param object A `model_spec` or a `recipe`.
#' @param full A single logical. Should all possible varying parameters be
#'  returned? If `FALSE` (the default), then only the parameters that
#'  are actually varying are returned.
#'
#' @param ... Not currently used.
#'
#' @return A tibble with columns for the parameter name (`name`), whether it
#'  contains _any_ varying value (`varying`), the `id` for the parameter
#'  (`id`), and the information on where the parameter was located (`source`).
#'
#' @examples
#'
#' # List all possible varying args for the random forest spec
#' rand_forest() %>% vary_args(full = TRUE)
#'
#' # mtry is now recognized as varying
#' rand_forest(mtry = vary()) %>% vary_args(full = TRUE)
#'
#' # Even engine specific arguments can vary
#' rand_forest() %>%
#'   set_engine("ranger", sample.fraction = vary()) %>%
#'   vary_args(full = TRUE)
#'
#' # List only the arguments that actually vary
#' rand_forest() %>%
#'   set_engine("ranger", sample.fraction = vary()) %>%
#'   vary_args(full = FALSE)
#'
#' rand_forest() %>%
#'   set_engine(
#'     "randomForest",
#'     strata = Class,
#'     sampsize = vary()
#'   ) %>%
#'   vary_args(full = TRUE)
#'
#' @importFrom purrr map map_chr
#' @export
vary_args.model_spec <- function(object, full = FALSE, ...) {
  # Nothing to inspect: return the standardized zero-row tibble.
  if (length(object$args) == 0L && length(object$eng_args) == 0L) {
    return(vary_tbl())
  }

  # Unwrap quosures first so that `find_vary_id()` sees plain expressions.
  main_args <- map(object$args, convert_args)
  engine_args <- map(object$eng_args, convert_args)

  # One string per argument: NA when the argument is not varying, "" when it
  # is varying without an identifier, and the identifier otherwise.
  res <- c(
    map_chr(main_args, find_vary_id),
    map_chr(engine_args, find_vary_id)
  )

  vary_tbl(
    name = names(res),
    varying = unname(!is.na(res)),
    id = unname(res),
    # the top-level class of the model specification labels the source
    source = paste("model_spec:", class(object)[1]),
    full = full
  ) %>%
    # placeholders without an identifier fall back to the argument name
    mutate(id = ifelse(id == "", name, id))
}
# Unwrap a single quosure into its bare expression; any other input is
# returned untouched. This is applied to each argument up front because
# mapping over a list of arguments in which some are quosures gives the
# message "Subsetting quosures with `[[` is deprecated as of rlang 0.4.0".
convert_args <- function(x) {
  if (!is_quosure(x)) {
    return(x)
  }
  rlang::quo_get_expr(x)
}
#' @importFrom purrr map_dfr
#' @export
#' @rdname vary_args.model_spec
vary_args.recipe <- function(object, full = FALSE, ...) {
  steps <- object$steps

  # A recipe without steps has nothing to scan: return the standardized
  # zero-row tibble.
  if (length(steps) == 0L) {
    return(vary_tbl())
  }

  # Dispatch `vary_args()` on every step and row-bind the per-step tibbles.
  map_dfr(steps, vary_args, full = full)
}
#' @importFrom purrr map_chr map_lgl
#' @export
#' @rdname vary_args.model_spec
vary_args.step <- function(object, full = FALSE, ...) {
  # Unique step id, used to label the `source` column.
  step_id <- object$id

  # Keep only the arguments that could be tuned:
  #  - NULL arguments are reserved for deprecated args or those set at prep()
  #    time;
  #  - arguments listed in `non_varying_step_arguments` are never tunable.
  is_candidate <- !map_lgl(object, is.null) &
    !(names(object) %in% non_varying_step_arguments)
  candidates <- object[is_candidate]

  # A step whose arguments are all fixed leaves nothing to scan; return the
  # standardized zero-row tibble, as the other methods do for empty input.
  if (length(candidates) == 0L) {
    return(vary_tbl())
  }

  # TODO: ensure the user didn't specify a non-varying argument as vary(),
  # e.g. with validate_only_allowed_step_args(). That check needs the step
  # class (`class(object)[1]`) and has to see the arguments before the
  # non-varying ones are dropped above.

  # One string per argument: NA when the argument is not varying, "" when it
  # is varying without an identifier, and the identifier otherwise.
  res <- map_chr(candidates, find_vary_id)

  vary_tbl(
    name = names(res),
    varying = unname(!is.na(res)),
    id = unname(res),
    source = paste("recipe:", step_id),
    full = full
  ) %>%
    # placeholders without an identifier fall back to the argument name
    mutate(id = ifelse(id == "", name, id))
}
# Standardized constructor for the tibble returned by the `vary_args()`
# methods. Called without arguments it yields a zero-row tibble (e.g. for a
# recipe that has no steps). Unless `full` is TRUE, only the rows flagged as
# varying are kept.
vary_tbl <- function(name = character(),
                     varying = logical(),
                     id = character(),
                     source = character(),
                     full = FALSE) {
  out <- tibble(name = name, varying = varying, id = id, source = source)

  if (full) {
    return(out)
  }

  out[out$varying, ]
}
# Check a named collection of per-argument varying flags for a recipe step and
# abort when an argument listed in `non_varying_step_arguments` is flagged as
# anything other than a single FALSE. `step_type` is only used in the error
# message. Returns `x` invisibly.
validate_only_allowed_step_args <- function(x, step_type) {
  purrr::iwalk(x, function(flag, nm) {
    # Only arguments that are not marked FALSE (i.e. not "not varying") and
    # whose name is on the non-varying list are a problem.
    if (!rlang::is_false(flag) && nm %in% non_varying_step_arguments) {
      rlang::abort(glue::glue(
        "The following argument for a recipe step of type ",
        "'{step_type}' is not allowed to vary: '{nm}'."
      ))
    }
    invisible(flag)
  })

  invisible(x)
}
# Names of recipe step arguments that are never treated as tunable. Entries
# are grouped by first letter; the order of the values is unchanged.
non_varying_step_arguments <- c(
  "...",
  "abbr",
  "base",
  "class", "column", "columns", "convert", "custom_token",
  "data", "default", "denom", "dictionary",
  "features", "func",
  "id", "impute_with", "index", "input", "inputs", "inverse",
  "keep", "key",
  "label", "language", "lat", "levels", "limits", "log", "lon",
  "mapping", "max", "means", "medians", "min", "models", "modes",
  "na_rm", "name", "names", "naming", "new_level", "norm", "normalize",
  "object", "objects", "options", "ordinal", "other", "outcome",
  "pattern", "pct", "percentage", "predictors", "prefix", "preserve",
  "profile",
  "ranges", "ratio", "recipe", "ref_data", "ref_first", "removals",
  "replace", "res", "result", "retain", "reverse", "role",
  "sds", "seed", "seed_val", "sep", "skip", "statistic", "strict",
  "sublinear_tf",
  "target", "terms", "trained", "transform",
  "use",
  "value", "verbose", "vocabulary",
  "x",
  "zero_based"
)
# helpers ----------------------------------------------------------------------

# Inspect `x` itself (not its elements) and return the identifier of a
# `vary()` call:
#  - the first argument of the call when one was given, e.g. `vary("note")`;
#  - "" when `x` is a bare `vary()` call;
#  - `na_chr` when `x` is anything else (NULL, a constant, a symbol, another
#    call, a list of quosures, ...).
vary_id <- function(x) {
  if (is.null(x)) {
    return(na_chr)
  }

  if (is_quosures(x)) {
    # Try to evaluate to catch things in the global envir; fall back to the
    # bare expressions when evaluation fails.
    evaluated <- try(map(x, eval_tidy), silent = TRUE)
    if (inherits(evaluated, "try-error")) {
      x <- map(x, quo_get_expr)
    } else {
      x <- evaluated
    }
    # `x` is now a list, which is not a call, so the check below reports it
    # as not varying.
  }

  # `vary()` always returns a call object. `is_call(x, "vary")` is used rather
  # than comparing `call_name(x)` to "vary": `call_name()` yields NULL for
  # calls whose function is not a plain or namespaced symbol (e.g.
  # `foo$bar()`), and `if (NULL == "vary")` fails with "argument is of length
  # zero".
  if (!is_call(x, "vary")) {
    return(na_chr)
  }

  # The id, when given, is the first argument of the call.
  if (length(x) > 1) {
    return(x[[2]])
  }

  # no id
  ""
}
# Search `x` and, recursively, its elements for a `vary()` placeholder.
#
# Returns a single string: the placeholder's identifier, "" when the
# placeholder has no identifier, or `na_chr` when `x` contains no placeholder.
# Stops with an error when more than one placeholder is found in `x`; the
# callers collect the results with `map_chr()` and need exactly one string
# per argument.
find_vary_id <- function(x) {
  # STEP 1 - Early exits

  # Early exit for empty elements (like list())
  if (length(x) == 0L) {
    return(na_chr)
  }

  # turn quosures into values or expressions before continuing
  if (is_quosures(x)) {
    # Try to evaluate to catch things in the global envir. If it is a dplyr
    # selector, it will fail to evaluate.
    evaluated <- try(map(x, eval_tidy), silent = TRUE)
    if (inherits(evaluated, "try-error")) {
      x <- map(x, quo_get_expr)
    } else {
      x <- evaluated
    }
  }

  # `x` itself is a `vary()` call
  id <- vary_id(x)
  if (!is.na(id)) {
    return(id)
  }

  # Nothing left to descend into
  if (is.atomic(x) || is.name(x) || length(x) == 1) {
    return(na_chr)
  }

  # STEP 2 - Recursion

  # Collect the result for every element without assuming its length, so that
  # nothing is truncated silently, then drop the elements without a
  # placeholder.
  found <- lapply(seq_along(x), function(i) find_vary_id(x[[i]]))
  found <- as.character(unlist(found, use.names = FALSE))
  found <- found[!is.na(found)]

  if (length(found) > 1) {
    stop(
      "Only one varying value is currently allowed per argument. ",
      "The current argument has: `",
      paste0(deparse(x), collapse = ""),
      "`.",
      call. = FALSE)
  }

  if (length(found) == 0) {
    return(na_chr)
  }

  found
}
# ------------------------------------------------------------------------------
# Examples: `vary_id()` inspects the object itself, `find_vary_id()` also
# looks inside its elements.

# Placeholders without an identifier
f_1 <- vary()
f_2 <- vary(id = "")

# Placeholders with an identifier
f_3 <- vary(id = "2")
f_4 <- vary("soup")

# A call without any placeholder, and the same call with a placeholder as one
# of its arguments
f_5 <- call2("lm", formula = Species ~ ., .ns = "stats")
f_6 <- call2("lm", formula = Species ~ ., model = vary(), .ns = "stats")

vary_id(f_1)
find_vary_id(f_1)

vary_id(f_2)
find_vary_id(f_2)

vary_id(f_3)
find_vary_id(f_3)

vary_id(f_4)
find_vary_id(f_4)

vary_id(f_5)
find_vary_id(f_5)

vary_id(f_6)       # checks entire object
find_vary_id(f_6)  # looks inside the elements of the object
# ------------------------------------------------------------------------------
# Examples: model specifications

# No placeholders at all
rand_forest() %>% vary_args()

# A main argument, without and with an identifier
rand_forest(mtry = vary()) %>% vary_args()

rand_forest(mtry = vary("rf_mtry")) %>% vary_args()

# An engine-specific argument
rand_forest() %>%
  set_engine("ranger", sample.fraction = vary()) %>%
  vary_args()

# An engine argument set to an ordinary function call rather than `vary()`
bar <- function() 2

rand_forest() %>%
  set_engine("ranger", sample.fraction = bar()) %>%
  vary_args()

# A main argument and an engine argument together
rand_forest(min_n = vary()) %>%
  set_engine("ranger", sample.fraction = vary(id = "rf sampling param")) %>%
  vary_args()

# should throw error for only 1 varying per argument. The call is wrapped in
# `try()` so that the error is reported without stopping a `source()` of this
# file before the remaining examples have run.
try(
  rand_forest() %>%
    set_engine("ranger", strata = expr(Class), sampsize = c(vary(), vary())) %>%
    vary_args()
)

# Build a tuning grid: for every varying argument, call the dials function
# that has the same name as the argument and finalize the resulting object
# with the `lending_club` data.
# NOTE(review): `lending_club` must already be available in the session;
# confirm which package provides it.
model_ex <-
  boost_tree(
    mtry = vary(), trees = vary(), min_n = vary(), tree_depth = vary(),
    sample_size = vary()
  ) %>%
  set_engine("xgboost") %>%
  vary_args() %>%
  mutate(
    arg = map(name, ~ call2(.x, .ns = "dials")),
    arg = map(arg, eval_tidy),
    arg = map(arg, finalize, x = lending_club)
  )

set.seed(462)

# The grid columns are renamed with the parameter ids
grid_latin_hypercube(!!!model_ex$arg, size = 20) %>%
  setNames(model_ex$id)
# ------------------------------------------------------------------------------
# Examples: recipes

# Two steps of the same type tune the same argument; the ids given to `vary()`
# tell them apart.
recipe(Species ~ ., data = iris) %>%
  step_bs(Sepal.Width, deg_free = 2, degree = vary("width degree")) %>%
  step_bs(Sepal.Length, deg_free = 2, degree = vary("length degree")) %>%
  step_pca(all_predictors(), num_comp = vary()) %>%
  vary_args()

rec_example <-
  recipe(Sepal.Length ~ ., data = iris) %>%
  step_bs(
    Sepal.Width,
    deg_free = vary(),
    degree = vary("Sepal Width spline degree")
  ) %>%
  step_corr(Species, threshold = vary("corr threshold")) %>%
  step_isomap(
    all_predictors(),
    num_terms = vary(),
    neighbors = vary("isomap K")
  ) %>%
  vary_args()

rec_example

# For every varying argument, call the dials function that has the same name
# as the argument and finalize the resulting object with the `iris` data.
rec_example <-
  rec_example %>%
  mutate(
    arg = map(
      name,
      ~ finalize(eval_tidy(call2(.x, .ns = "dials")), x = iris)
    )
  )

set.seed(462)

# The grid columns are renamed with the parameter ids
grid_max_entropy(!!!rec_example$arg, size = 2^length(rec_example$arg)) %>%
  setNames(rec_example$id)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment