# Gist by @topepo, created August 4, 2019
# (https://gist.github.com/topepo/518bd25071e25807a32e07d7eb1bb30a):
# Alternative infrastructure for finding and coordinating varying parameters in
# recipes and models.
# A prototype mechanism to collect the tunable parameters and their sources. It
# also allows parameters with the same name to be tuned. For example, you might
# model some variables with a spline but want to allow for different degrees
# of freedom by adding a different `step_ns()` for each variable. This interface
# allows the user to add an annotation for the parameter so that we can tell the
# different parameters apart.
# Limitations:
# - Currently, only one varying value is allowed per argument. For example, if
# some argument had a length two vector as its format, you cannot do something
# like `c(vary(), vary())`.
# - The last bit of code that generates the actual tuning grid assumes that:
# * The parameter generating functions are in dials.
# * All parameters have a generating function.
# - We would need to revise the code that updates recipes and models with new
# parameter values so they can use the parameter annotations to find the
# right argument to change.
# ------------------------------------------------------------------------------
# devtools::install_github("tidymodels/dials")
library(tidymodels)
library(rlang)
# ------------------------------------------------------------------------------
# I used `vary()` instead of `varying()` here to avoid conflicting function names
# We can go back to `varying()` if this is the way to go
#' A placeholder function for argument values that are to be tuned.
#'
#' [vary()] is used when a parameter will be specified at a later date.
#' @param id A single, non-missing character value that can be used to
#' differentiate parameters that are used in multiple places but have the same
#' name, or to attach a note to the parameter. The default `""` means "no id".
#' @return An unevaluated call object that echoes the user input: `vary()` when
#' `id` is `""`, otherwise `vary("<id>")`.
#' @examples
#' vary()
#' class(vary())
#' vary("your name here")
#'
#' @export
vary <- function(id = "") {
  # `NA_character_` passes `is.character()`, so reject it explicitly; otherwise
  # the `id == ""` test below fails with an opaque "missing value" error.
  if (!is.character(id) || length(id) != 1 || is.na(id)) {
    stop("The `id` should be a single character string.", call. = FALSE)
  }
  # Leave the argument out entirely when there is no id so that `vary()` and
  # `vary(id = "")` build exactly the same call.
  if (id == "") {
    call("vary")
  } else {
    call("vary", id)
  }
}
# S3 generic: dispatches on the class of `object` (the methods in this file
# handle `model_spec`, `recipe`, and `step`). Extra arguments are forwarded to
# the selected method.
#' @export
vary_args <- function(object, ...) UseMethod("vary_args")
#' Determine varying arguments
#'
#' `vary_args()` takes a model specification or a recipe and returns a tibble
#' of information on all possible varying arguments and whether or not they
#' are actually varying.
#'
#' The `source` column is determined differently depending on whether a
#' `model_spec` or a `recipe` is used (with additional detail on the type).
#'
#' The `id` field has any identifier that was passed to `vary()` (e.g.
#' `vary("some note")`). If no additional detail was used in that function,
#' the `id` field reverts to the name of the parameter.
#'
#' @param object A `model_spec` or a `recipe`.
#' @param full A single logical. Should all possible varying parameters be
#' returned? If `FALSE` (the default), then only the parameters that
#' are actually varying are returned.
#'
#' @param ... Not currently used.
#'
#' @return A tibble with columns for the parameter name (`name`), whether it
#' contains _any_ varying value (`varying`), the `id` for the parameter (`id`),
#' and the information on where the parameter was located (`source`).
#'
#' @examples
#'
#' # List all possible varying args for the random forest spec
#' rand_forest() %>% vary_args(full = TRUE)
#'
#' # mtry is now recognized as varying
#' rand_forest(mtry = vary()) %>% vary_args()
#'
#' # Even engine specific arguments can vary
#' rand_forest() %>%
#'   set_engine("ranger", sample.fraction = vary()) %>%
#'   vary_args()
#'
#' # List only the arguments that actually vary (the default, `full = FALSE`)
#' rand_forest() %>%
#'   set_engine("ranger", sample.fraction = vary()) %>%
#'   vary_args(full = FALSE)
#'
#' rand_forest() %>%
#'   set_engine(
#'     "randomForest",
#'     strata = Class,
#'     sampsize = vary()
#'   ) %>%
#'   vary_args()
#'
#' @importFrom purrr map map_chr map_lgl
#' @export
vary_args.model_spec <- function(object, full = FALSE, ...) {
  # The top-level class of the spec (e.g. "rand_forest") labels the source.
  spec_class <- class(object)[1]

  # `&&`: both operands are scalars and this is an `if` condition.
  if (length(object$args) == 0L && length(object$eng_args) == 0L) {
    return(vary_tbl())
  }

  # Arguments may be stored as quosures; unwrap them to plain expressions
  # before searching for `vary()` calls.
  args <- map(object$args, convert_args)
  eng_args <- map(object$eng_args, convert_args)

  # Locate varying args in the main and the engine-specific arguments.
  # `find_vary_id()` gives NA for a non-varying argument, "" for `vary()`
  # without an id, and the id string otherwise.
  res <- c(map_chr(args, find_vary_id), map_chr(eng_args, find_vary_id))

  vary_tbl(
    name = names(res),
    varying = unname(!is.na(res)),
    # unname() so the `id` column is a plain character vector, matching
    # `vary_args.step()`.
    id = unname(res),
    source = paste("model_spec:", spec_class),
    full = full
  ) %>%
    # Fall back to the argument name when `vary()` was given no id.
    mutate(id = ifelse(id == "", name, id))
}
# Unwrap a quosure into its bare expression; any other value is returned
# untouched. This matters because mapping over a list of arguments that
# contains quosures gives the message "Subsetting quosures with `[[` is
# deprecated as of rlang 0.4.0".
convert_args <- function(x) {
  if (!is_quosure(x)) {
    return(x)
  }
  rlang::quo_get_expr(x)
}
#' @importFrom purrr map_dfr map2_dfr map_chr
#' @export
#' @rdname vary_args.model_spec
vary_args.recipe <- function(object, full = FALSE, ...) {
  steps <- object$steps
  # A recipe without steps has nothing to tune: return the empty result.
  if (length(steps) == 0L) {
    return(vary_tbl())
  }
  # Each step describes its own arguments via the `step` method; the per-step
  # tibbles are stacked row-wise.
  map_dfr(steps, vary_args, full = full)
}
#' @importFrom purrr map map_lgl
#' @export
#' @rdname vary_args.model_spec
vary_args.step <- function(object, full = FALSE, ...) {
  # Read the step's unique id and class up front: subsetting the step with
  # `[` below drops its class.
  step_id <- object$id
  step_type <- class(object)[1]

  # NULL entries are reserved for deprecated arguments or values filled in at
  # prep() time, so they are never candidates.
  candidates <- object[!map_lgl(object, is.null)]
  # Arguments that can never vary are irrelevant here.
  is_fixed <- names(candidates) %in% non_varying_step_arguments
  candidates <- candidates[!is_fixed]

  # ensure the user didn't specify a non-varying argument as vary()
  # TODO update this
  # validate_only_allowed_step_args(ids, step_type)

  ids <- map_chr(candidates, find_vary_id)
  tbl <- vary_tbl(
    name = names(ids),
    varying = unname(!is.na(ids)),
    id = unname(ids),
    source = paste("recipe:", step_id),
    full = full
  )
  # Without an explicit id, identify the parameter by its argument name.
  mutate(tbl, id = ifelse(id == "", name, id))
}
# Build the standard result tibble shared by the `vary_args()` methods.
# Called with no arguments it yields a zero-row tibble with the right columns
# (e.g. for a recipe that has no steps). Unless `full` is TRUE, only rows
# whose `varying` flag is TRUE are kept.
vary_tbl <- function(name = character(),
                     varying = logical(),
                     id = character(),
                     source = character(),
                     full = FALSE) {
  out <- tibble(name = name, varying = varying, id = id, source = source)
  if (full) {
    return(out)
  }
  out[out$varying, ]
}
# Abort when an argument listed in `non_varying_step_arguments` was marked as
# varying.
#
# `x` is a named vector/list with one element per step argument. An element
# counts as "not varying" when it is FALSE (logical flags) or NA (what
# `find_vary_id()` returns for an argument without `vary()`). Any other value,
# such as TRUE, "" or an id string, counts as varying. Previously only FALSE
# was recognised, so character results from `find_vary_id()` would have
# flagged every fixed argument.
#
# Returns `x` invisibly when no fixed argument is varying.
validate_only_allowed_step_args <- function(x, step_type) {
  check_allowed_arg <- function(x, nm) {
    # not varying
    if (rlang::is_false(x) || all(is.na(x))) {
      return(invisible(x))
    }
    # varying, but this argument is allowed to vary
    if (!nm %in% non_varying_step_arguments) {
      return(invisible(x))
    }
    rlang::abort(glue::glue(
      "The following argument for a recipe step of type ",
      "'{step_type}' is not allowed to vary: '{nm}'."
    ))
  }
  purrr::iwalk(x, check_allowed_arg)
  invisible(x)
}
# Step arguments that are never treated as tunable. `vary_args.step()` drops
# any argument with one of these names before searching for `vary()`.
# Grouped by initial letter; the order is unchanged.
non_varying_step_arguments <- c(
  "...",
  "abbr",
  "base",
  "class", "column", "columns", "convert", "custom_token",
  "data", "default", "denom", "dictionary",
  "features", "func",
  "id", "impute_with", "index", "input", "inputs", "inverse",
  "keep", "key",
  "label", "language", "lat", "levels", "limits", "log", "lon",
  "mapping", "max", "means", "medians", "min", "models", "modes",
  "na_rm", "name", "names", "naming", "new_level", "norm", "normalize",
  "object", "objects", "options", "ordinal", "other", "outcome",
  "pattern", "pct", "percentage", "predictors", "prefix", "preserve",
  "profile",
  "ranges", "ratio", "recipe", "ref_data", "ref_first", "removals",
  "replace", "res", "result", "retain", "reverse", "role",
  "sds", "seed", "seed_val", "sep", "skip", "statistic", "strict",
  "sublinear_tf",
  "target", "terms", "trained", "transform",
  "use",
  "value", "verbose", "vocabulary",
  "x",
  "zero_based"
)
# helpers ----------------------------------------------------------------------
# Return the `id` given to `vary()`: the id string when one was supplied, ""
# for a bare `vary()`, and `na_chr` when `x` is not a `vary()` call.
#
# Only the top level of `x` is inspected; `find_vary_id()` recurses into
# nested expressions.
vary_id <- function(x) {
  if (is.null(x)) {
    return(na_chr)
  }
  # A single quosure: look at the expression it wraps.
  if (is_quosure(x)) {
    x <- quo_get_expr(x)
  }
  if (is_quosures(x)) {
    # Try to evaluate to catch things in the global envir; fall back to the
    # bare expressions when evaluation fails.
    .x <- try(map(x, eval_tidy), silent = TRUE)
    if (inherits(.x, "try-error")) {
      x <- map(x, quo_get_expr)
    } else {
      x <- .x
    }
  }
  # `vary()` always produces a call object. `is_call()` is FALSE (never an
  # error) for non-calls and anonymous calls. The former
  # `call_name(x) == "vary"` gave logical(0), and broke the `if`, whenever
  # `call_name()` returned NULL.
  if (!is_call(x, "vary")) {
    return(na_chr)
  }
  # An id, if supplied, is the first (and only) argument of the call.
  if (length(x) > 1) {
    x[[2]]
  } else {
    ""
  }
}
# Recursively search `x` for a `vary()` call and return its id.
#
# Returns `na_chr` when `x` contains no `vary()` call, "" when it contains a
# `vary()` without an id, and the id string otherwise. Errors when more than
# one `vary()` call appears inside the same argument, since only one varying
# value per argument is currently supported.
find_vary_id <- function(x) {
  # STEP 1 - Early exits
  # Empty elements (like list()) cannot contain a `vary()` call.
  if (length(x) == 0L) {
    return(na_chr)
  }
  # turn quosures into expressions before continuing
  if (is_quosures(x)) {
    # Try to evaluate to catch things in the global envir. If it is a dplyr
    # selector, it will fail to evaluate.
    .x <- try(map(x, eval_tidy), silent = TRUE)
    if (inherits(.x, "try-error")) {
      x <- map(x, quo_get_expr)
    } else {
      x <- .x
    }
  }
  id <- vary_id(x)
  if (!is.na(id)) {
    return(id)
  }
  # Atomic vectors, symbols, and length-one objects are not searched further.
  if (is.atomic(x) || is.name(x) || length(x) == 1) {
    return(na_chr)
  }
  # STEP 2 - Recursion into each element (for a call: its function and its
  # arguments), collecting one id-or-NA per element.
  ids <- vector("character", length = length(x))
  for (i in seq_along(x)) {
    ids[i] <- find_vary_id(x[[i]])
  }
  ids <- ids[!is.na(ids)]
  if (length(ids) == 0L) {
    return(na_chr)
  }
  # Reject ANY second `vary()`, with or without an id. Counting only id-less
  # calls let e.g. `c(vary("a"), vary("b"))` return a length-two result that
  # then broke `map_chr()` in the callers.
  if (length(ids) > 1L) {
    stop(
      "Only one varying value is currently allowed per argument. ",
      "The current argument has: `",
      paste0(deparse(x), collapse = ""),
      "`.",
      call. = FALSE)
  }
  ids
}
# ------------------------------------------------------------------------------
# Examples:
# `vary()` returns an unevaluated call, so f_1 ... f_4 are call objects:
# f_1 and f_2 are both `vary()` (an empty id is dropped), while f_3 and f_4
# carry ids. f_5 and f_6 are `stats::lm()` calls; only f_6 nests a `vary()`.
f_1 <- vary()
f_2 <- vary(id = "")
f_3 <- vary(id = "2")
f_4 <- vary("soup")
f_5 <- call2("lm", formula = Species ~ ., .ns = "stats")
f_6 <- call2("lm", formula = Species ~ ., model = vary(), .ns = "stats")
# For a top-level `vary()` call both helpers return "" (no id) or the id.
vary_id(f_1)
find_vary_id(f_1)
vary_id(f_2)
find_vary_id(f_2)
vary_id(f_3)
find_vary_id(f_3)
vary_id(f_4)
find_vary_id(f_4)
# f_5 has no `vary()` anywhere, so there is nothing to find.
vary_id(f_5)
find_vary_id(f_5)
vary_id(f_6) # checks entire object (top-level call is not `vary()`)
find_vary_id(f_6) # looks inside the elements of the object (finds `model = vary()`)
# ------------------------------------------------------------------------------
# Model specifications. With the default `full = FALSE` only arguments that
# contain `vary()` are listed, so a spec without any yields zero rows.
rand_forest() %>% vary_args()
# `mtry` varies; without an id, its `id` column falls back to "mtry".
rand_forest(mtry = vary()) %>% vary_args()
rand_forest(mtry = vary("rf_mtry")) %>% vary_args()
# Engine-specific arguments are searched too.
rand_forest() %>%
set_engine("ranger", sample.fraction = vary()) %>%
vary_args()
# An ordinary function call is not mistaken for `vary()`.
bar <- function() 2
rand_forest() %>%
set_engine("ranger", sample.fraction = bar()) %>%
vary_args()
# Main and engine arguments can vary together, each with its own id.
rand_forest(min_n = vary()) %>%
set_engine("ranger", sample.fraction = vary(id = "rf sampling param")) %>%
vary_args()
# Two `vary()` calls inside one argument are rejected, because only one
# varying value is allowed per argument. The expected error is caught and
# reported so that the rest of this script still runs when it is sourced.
tryCatch(
  rand_forest() %>%
    set_engine(
      "ranger",
      strata = expr(Class),
      sampsize = c(vary(), vary())
    ) %>%
    vary_args(),
  error = function(e) message("Expected error: ", conditionMessage(e))
)
# Tune a boosted tree: all five listed main arguments are marked with `vary()`.
model_ex <-
boost_tree(
mtry = vary(), trees = vary(), min_n = vary(), tree_depth = vary(),
sample_size = vary()
) %>%
set_engine("xgboost") %>%
vary_args() %>%
# For each varying argument, build a call to the dials function of the same
# name (e.g. `dials::mtry()`), evaluate it into a parameter object, and then
# finalize data-dependent ranges. This relies on every argument name having a
# matching dials function (see the limitations at the top of the file).
# NOTE(review): `lending_club` is not created in this file; confirm that it is
# provided by an attached package before running.
mutate(
arg = map(name, ~ call2(.x, .ns = "dials")),
arg = map(arg, eval_tidy),
arg = map(arg, finalize, x = lending_club)
)
# Latin hypercube design with 20 candidates; columns are renamed to the
# parameter ids reported by vary_args().
set.seed(462)
grid_latin_hypercube(!!!model_ex$arg, size = 20) %>%
setNames(model_ex$id)
# ------------------------------------------------------------------------------
# Two `step_bs()` steps both tune an argument named `degree`; the distinct ids
# ("width degree", "length degree") let the two parameters be told apart.
recipe(Species ~ ., data = iris) %>%
step_bs(Sepal.Width, deg_free = 2, degree = vary("width degree")) %>%
step_bs(Sepal.Length, deg_free = 2, degree = vary("length degree")) %>%
step_pca(all_predictors(), num_comp = vary()) %>%
vary_args()
# A mix of annotated and un-annotated varying arguments; un-annotated ones use
# the argument name as their id.
rec_example <-
recipe(Sepal.Length ~ ., data = iris) %>%
step_bs(Sepal.Width, deg_free = vary(), degree = vary("Sepal Width spline degree")) %>%
step_corr(Species, threshold = vary("corr threshold")) %>%
step_isomap(all_predictors(), num_terms = vary(), neighbors = vary("isomap K")) %>%
vary_args()
rec_example
# Turn each varying argument into a dials parameter object and finalize its
# range against `iris`.
rec_example <-
rec_example %>%
mutate(
arg = map(name, ~ call2(.x, .ns = "dials")),
arg = map(arg, eval_tidy),
arg = map(arg, finalize, x = iris)
)
set.seed(462)
# Request 2^p candidates for p varying parameters; columns are named by id.
grid_max_entropy(!!!rec_example$arg, size = 2^length(rec_example$arg)) %>%
setNames(rec_example$id)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment