@coolbutuseless
Created September 20, 2018 08:22
Stricter version of dplyr::case_when()
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#' Stricter version of case_when()
#' - disallows a fall-through rule, i.e. a literal `TRUE` on the LHS.
#' - disallows input values that do not match any rule.
#' - disallows input values that match more than one rule.
#'
#' @param ... Two-sided formulas, passed on unchanged to `dplyr::case_when()`.
#'
#' @return A vector of length 1 or n, matching the length of the logical input
#' or output vectors, with the type (and attributes) of the first RHS.
#' Inconsistent lengths or types will generate an error.
#'
#' @import dplyr
#' @import rlang
#' @import purrr
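#' @examples
#' # Illustrative calls (a sketch; `x` is made-up example data).
#' x <- c(1, 5, 10)
#'
#' # Each element matches exactly one rule, so this returns c("low", "mid", "high")
#' strict_case_when(x < 3 ~ "low", x >= 3 & x < 8 ~ "mid", x >= 8 ~ "high")
#'
#' # Each of these should stop with an error:
#' try(strict_case_when(x < 3 ~ "low", TRUE ~ "other"))  # fall-through TRUE
#' try(strict_case_when(x < 3 ~ "low", x > 8 ~ "high"))  # x[2] matches no rule
#' try(strict_case_when(x < 8 ~ "low", x > 3 ~ "high"))  # x[2] matches two rules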
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when <- function(...) {
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Let case_when() do its thing first: it errors on malformed arguments.
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  res <- dplyr::case_when(...)
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # If case_when() succeeded, every argument can be assumed to be a
  # well-formed two-sided formula with a proper LHS.
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  formulas <- rlang::list2(...)
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Check for fall-through 'TRUE' value
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  lhs <- formulas %>% purrr::map(rlang::f_lhs)
  if (any(purrr::map_lgl(lhs, isTRUE))) {
    stop("strict_case_when(): fall-through 'TRUE' is not allowed", call. = FALSE)
  }
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Count how many rules match each input element. Each LHS must be
  # evaluated in the environment of its own formula. NA counts as no match.
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  match_counts <- formulas %>%
    purrr::map(~rlang::eval_bare(rlang::f_lhs(.x), env = rlang::f_env(.x))) %>%
    purrr::transpose() %>%
    purrr::map(purrr::flatten_lgl) %>%
    purrr::map_int(sum, na.rm = TRUE)
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Check whether any input values are unmatched
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  if (any(match_counts == 0L)) {
    stop("strict_case_when(): no matches found at the following input indices: ",
         paste(which(match_counts == 0L), collapse = ", "), call. = FALSE)
  }
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Check whether any input values matched more than one rule
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  if (any(match_counts > 1L)) {
    stop("strict_case_when(): multiple matches found at the following input indices: ",
         paste(which(match_counts > 1L), collapse = ", "), call. = FALSE)
  }
  res
}
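
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sketch: the same strictness checks written in base R only, to illustrate
# what the purrr::transpose() pipeline above computes. Assumes every LHS
# evaluates to a logical vector of the same length. `count_rule_matches()`,
# `rules` and `x` are hypothetical names used only for this illustration.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
count_rule_matches <- function(lhs_values) {
  # Bind the evaluated conditions column-wise (rows = input elements,
  # columns = rules), then count the TRUEs in each row; NA counts as no match.
  m <- do.call(cbind, lhs_values)
  as.integer(rowSums(m, na.rm = TRUE))
}

x <- c(1, 5, 10)
rules <- list(x < 3 ~ "low", x >= 3 ~ "mid", x > 8 ~ "high")

# A literal TRUE on the LHS is a fall-through rule; f[[2]] is a formula's LHS.
has_fallthrough <- any(vapply(rules, function(f) isTRUE(f[[2]]), logical(1)))

# Evaluate each LHS in the environment the formula was created in.
lhs_values <- lapply(rules, function(f) eval(f[[2]], environment(f)))
counts <- count_rule_matches(lhs_values)

counts               # 1 1 2
which(counts == 0L)  # integer(0): every element matched at least one rule
which(counts > 1L)   # 3: x[3] = 10 matches both `x >= 3` and `x > 8`

# Binding into a matrix and calling rowSums() yields one count per input
# element in a single vectorised step, mirroring the gist's transpose/sum.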