Created September 20, 2018 08:22
-
-
Save coolbutuseless/7f86eedceb5c0ce6de8bf2dc6b3a0844 to your computer and use it in GitHub Desktop.
Stricter version of dplyr::case_when()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#' Stricter version of case_when()
#'
#' A wrapper around `dplyr::case_when()` that additionally:
#' - disallows a fall-through `TRUE` (or `T`) value on the LHS, and a
#'   non-NULL `.default` argument;
#' - disallows input values which do not match any rule;
#' - disallows input values which match more than one rule.
#'
#' A condition that evaluates to `NA` counts as "no match" for that position,
#' as in `case_when()`. Length-1 conditions are recycled to the output length
#' before matches are counted.
#'
#' @param ... arguments passed unchanged to `dplyr::case_when()`; the
#'   two-sided formulas among them are the rules that get checked.
#'
#' @return A vector of length 1 or n, matching the length of the logical input
#'   or output vectors, with the type (and attributes) of the first RHS.
#'   Inconsistent lengths or types will generate an error.
#'
#' @import dplyr
#' @import rlang
#' @import purrr
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when <- function(...) {
  # Let case_when() validate the arguments (formula shape, lengths, types)
  # and compute the result. If it succeeds, every rule is a well-formed
  # two-sided formula and every condition has length 1 or length(res).
  res <- dplyr::case_when(...)
  n <- length(res)

  dots <- rlang::list2(...)

  # A `.default` value is just another way of writing a fall-through rule.
  if (!is.null(dots[[".default"]])) {
    stop("strict_case_when(): fall-through '.default' is not allowed",
         call. = FALSE)
  }

  # Only the formulas are rules; other named arguments (e.g. `.ptype`,
  # `.size`) are options for case_when() and must not reach f_lhs().
  formulas <- Filter(rlang::is_formula, dots)

  # Check for a fall-through 'TRUE' (or the reassignable alias `T`) LHS.
  is_fall_through <- vapply(formulas, function(f) {
    lhs <- rlang::f_lhs(f)
    isTRUE(lhs) || identical(lhs, quote(T))
  }, logical(1))
  if (any(is_fall_through)) {
    stop("strict_case_when(): fall-through 'TRUE' is not allowed",
         call. = FALSE)
  }

  # Count how many rules match each output position. Each LHS is evaluated
  # in its formula's own environment. Length-1 conditions are recycled to
  # `n` so they count at every position, and NA is treated as no match.
  hits <- lapply(formulas, function(f) {
    cond <- rlang::eval_bare(rlang::f_lhs(f), env = rlang::f_env(f))
    cond <- rep_len(unname(cond), n)
    !is.na(cond) & cond
  })
  match_counts <- Reduce(`+`, hits, init = integer(n))

  # Check for input values that no rule matched.
  if (any(match_counts == 0L)) {
    stop("strict_case_when(): no matches found at the following input indices: ",
         deparse(which(match_counts == 0L)), call. = FALSE)
  }

  # Check for input values matched by more than one rule.
  if (any(match_counts > 1L)) {
    stop("strict_case_when(): multiple matches found at the following input indices: ",
         deparse(which(match_counts > 1L)), call. = FALSE)
  }

  res
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.