Skip to content

Instantly share code, notes, and snippets.

@DavisVaughan
Last active October 14, 2021 18:39
Show Gist options
  • Select an option

  • Save DavisVaughan/282ea2cfb88e0938bf0b655014d22c55 to your computer and use it in GitHub Desktop.

Select an option

Save DavisVaughan/282ea2cfb88e0938bf0b655014d22c55 to your computer and use it in GitHub Desktop.
# Reprex for the proposed case_when() / replace_when() / revise() API.
# Assumes dplyr (for `tibble()`, `group_by()` and `%>%`) is attached and the
# case_when(), replace_when() and revise() implementations further down are
# already loaded.
#
# Example data: `g` splits the rows into two groups of four, `x` counts up
# from 1 and `y` counts down from 8.
df <- tibble(
g = c(1, 1, 1, 1, 2, 2, 2, 2),
x = 1:8,
y = 8:1
)
# Group by `g` to show that these helpers are applied per group
df <- group_by(df, g)
# Plain (ungrouped) copy of `x` for the vector-level examples below
x <- df$x
x
#> [1] 1 2 3 4 5 6 7 8
df
#> # A tibble: 8 × 3
#> # Groups: g [2]
#> g x y
#> <dbl> <int> <int>
#> 1 1 1 8
#> 2 1 2 7
#> 3 1 3 6
#> 4 1 4 5
#> 5 2 5 4
#> 6 2 6 3
#> 7 2 7 2
#> 8 2 8 1
# case_when(): create 1 new column
# - 1 or N cases, checked in order; the first matching case wins
# - Unmatched rows get the user supplied `default`
case_when(
x > 6, 6,
x > 3, 3,
default = 1
)
#> [1] 1 1 1 3 3 3 6 6
# Replacement values can be vectors: each one is recycled to the full size and
# then sliced down to just the rows its condition claims. Inside a grouped
# mutate(), `mean(y)` and `y` are evaluated per group (mean(y) is 2.5 in
# group 2).
df %>%
mutate(
z = case_when(
x > 6, mean(y),
x > 3, y,
default = 1
)
)
#> # A tibble: 8 × 4
#> # Groups: g [2]
#> g x y z
#> <dbl> <int> <int> <dbl>
#> 1 1 1 8 1
#> 2 1 2 7 1
#> 3 1 3 6 1
#> 4 1 4 5 5
#> 5 2 5 4 4
#> 6 2 6 3 3
#> 7 2 7 2 2.5
#> 8 2 8 1 2.5
# case_when() with data frame values: create N new columns at once
# - 1 or N cases
# - Unmatched rows get the user supplied `default` (also a data frame)
case_when(
x > 6, tibble(x = 1, y = 2),
x > 3, tibble(x = 3, y = 6),
default = tibble(x = NA, y = NA)
)
#> # A tibble: 8 × 2
#> x y
#> <dbl> <dbl>
#> 1 NA NA
#> 2 NA NA
#> 3 NA NA
#> 4 3 6
#> 5 3 6
#> 6 3 6
#> 7 1 2
#> 8 1 2
# Data frame auto expansion for the win!
# The unnamed data frame result is expanded by mutate() into one new column
# per data frame column (`a` and `b` below).
# (CRAN dplyr can't do this)
df %>%
mutate(
case_when(
x > 6, tibble(a = 1, b = mean(y)),
x > 3, tibble(a = 3, b = y),
default = tibble(a = NA, b = NA)
)
)
#> # A tibble: 8 × 5
#> # Groups: g [2]
#> g x y a b
#> <dbl> <int> <int> <dbl> <dbl>
#> 1 1 1 8 NA NA
#> 2 1 2 7 NA NA
#> 3 1 3 6 NA NA
#> 4 1 4 5 3 5
#> 5 2 5 4 3 4
#> 6 2 6 3 3 3
#> 7 2 7 2 1 2.5
#> 8 2 8 1 1 2.5
# replace_when(): update 1 existing column
# - 1 or N cases
# - Unmatched rows keep the original value of the column
replace_when(
x,
x > 6, 6,
x > 3, 3
)
#> [1] 1 2 3 3 3 3 6 6
# In a grouped mutate(), `max(y)` and `y` are evaluated per group, and the
# result keeps the type of the original column (`x` is still <int>).
df %>%
mutate(
x = replace_when(
x,
x > 6, max(y),
x > 3, y
)
)
#> # A tibble: 8 × 3
#> # Groups: g [2]
#> g x y
#> <dbl> <int> <int>
#> 1 1 1 8
#> 2 1 2 7
#> 3 1 3 6
#> 4 1 5 5
#> 5 2 4 4
#> 6 2 3 3
#> 7 2 4 2
#> 8 2 4 1
# Comparison with base replace(): with a scalar replacement they agree
replace_when(x, x > 6, NA)
#> [1] 1 2 3 4 5 6 NA NA
replace(x, x > 6, NA)
#> [1] 1 2 3 4 5 6 NA NA
# But replace() gets this common pattern wrong: it fills the matching
# positions with the *first* elements of `y` (8, 7) and warns, whereas
# replace_when() takes `y` at the matching positions (2, 1).
y <- df$y
replace_when(x, x > 6, y)
#> [1] 1 2 3 4 5 6 2 1
replace(x, x > 6, y)
#> Warning in x[list] <- values: number of items to replace is not a multiple of replacement length
#> [1] 1 2 3 4 5 6 8 7
# replace_when() on a data frame: update N existing columns together
# - 1 or N cases
# - Unmatched rows keep the original values
replace_when(
tibble(x = x, y = y),
x > 6, tibble(x = NA, y = NA),
x > 3, tibble(x = max(x), y = max(y))
)
#> # A tibble: 8 × 2
#> x y
#> <int> <int>
#> 1 1 8
#> 2 2 7
#> 3 3 6
#> 4 8 8
#> 5 8 8
#> 6 8 8
#> 7 NA NA
#> 8 NA NA
# While possible, this is pretty clunky and also pretty rare.
# Inside the grouped mutate(), `max(x)` and `max(y)` are computed per group.
df %>%
mutate(
replace_when(
tibble(x = x, y = y),
x > 6, tibble(x = NA, y = NA),
x > 3, tibble(x = max(x), y = max(y))
)
)
#> # A tibble: 8 × 3
#> # Groups: g [2]
#> g x y
#> <dbl> <int> <int>
#> 1 1 1 8
#> 2 1 2 7
#> 3 1 3 6
#> 4 1 4 8
#> 5 2 8 4
#> 6 2 8 4
#> 7 2 NA NA
#> 8 2 NA NA
# More common would be to update multiple columns based on 1 condition
# (Can't use if_else() because we want type stability of `x`: the result
# below keeps `x` as <int> even though the replacement is a logical NA)
df %>%
mutate(
replace_when(
tibble(x = x, y = y),
x > 6, tibble(x = NA, y = NA)
)
)
#> # A tibble: 8 × 3
#> # Groups: g [2]
#> g x y
#> <dbl> <int> <int>
#> 1 1 1 8
#> 2 1 2 7
#> 3 1 3 6
#> 4 1 4 5
#> 5 2 5 4
#> 6 2 6 3
#> 7 2 NA NA
#> 8 2 NA NA
# But this case seems so common that we should provide a native dplyr helper
# (this seems like the only way to avoid specifying the columns twice)
revise(df, x > 6, x = NA, y = NA)
#> # A tibble: 8 × 3
#> # Groups: g [2]
#> g x y
#> <dbl> <int> <int>
#> 1 1 1 8
#> 2 1 2 7
#> 3 1 3 6
#> 4 1 4 5
#> 5 2 5 4
#> 6 2 6 3
#> 7 2 NA NA
#> 8 2 NA NA
# Note that revise() computes `...` on the filtered data, while replace_when()
# uses the entire group's data.
# I'm not sure if revise() is correct or not (it is what data.table does but
# feels kind of wrong for these computed column cases).
# Concretely: in group 2, `max(y)` is 2 under revise() (max of the filtered
# rows only) but 4 under replace_when() (max of the whole group).
revise(df, x > 6, y = max(y))
#> # A tibble: 8 × 3
#> # Groups: g [2]
#> g x y
#> <dbl> <int> <int>
#> 1 1 1 8
#> 2 1 2 7
#> 3 1 3 6
#> 4 1 4 5
#> 5 2 5 4
#> 6 2 6 3
#> 7 2 7 2
#> 8 2 8 2
mutate(df, y = replace_when(y, x > 6, max(y)))
#> # A tibble: 8 × 3
#> # Groups: g [2]
#> g x y
#> <dbl> <int> <int>
#> 1 1 1 8
#> 2 1 2 7
#> 3 1 3 6
#> 4 1 4 5
#> 5 2 5 4
#> 6 2 6 3
#> 7 2 7 4
#> 8 2 8 4
@DavisVaughan

DavisVaughan commented Oct 14, 2021

Copy link
Copy Markdown
Author

Implementation:

library(rlang)
library(vctrs)
library(glue)
devtools::load_all() # in dplyr project
#> ℹ Loading dplyr

# Vectorised multi-way choice over `where, value` pairs.
#
# `...` holds an even number of inputs that alternate between a logical
# `where` vector and the `value` to use where it is `TRUE`. Cases are checked
# in order and the first matching `where` claims a location. Locations that no
# `where` claims take `default`. `NA` in a `where` is treated like `FALSE`.
#
# @param ... Alternating `where` (logical) and `value` inputs.
# @param default Value for unclaimed locations. `NULL` means missing values of
#   the output type.
# @param ptype Optional output prototype. When `NULL`, the common type of all
#   values and `default` is used.
# @param size Optional output size. When `NULL`, it is the common size of the
#   `where` inputs. Every `where` must have exactly this size.
# @return A vector of type `ptype` and size `size`.
case_when <- function(...,
                      default = NULL,
                      ptype = NULL,
                      size = NULL) {
  args <- list2(...)
  args <- unname(args)

  n_args <- length(args)

  if (n_args == 0L) {
    abort("No cases provided.")
  }
  if ((n_args %% 2) != 0L) {
    abort("`...` must be an even number of inputs.")
  }

  n_wheres <- n_args / 2L
  n_values <- n_wheres + 1L

  loc_wheres <- seq.int(1L, n_args - 1L, by = 2)
  loc_values <- loc_wheres + 1L

  wheres <- args[loc_wheres]
  values <- args[loc_values]

  for (i in seq_len(n_wheres)) {
    where <- wheres[[i]]

    if (!is.logical(where)) {
      abort("Each 'where' input in `...` must evaluate to a logical vector.")
    }
    if (anyNA(where)) {
      # `NA` in `where` is skipped
      where[is.na(where)] <- FALSE
    }

    wheres[[i]] <- where
  }

  size <- vec_size_common(!!!wheres, .size = size)

  if (any(list_sizes(wheres) != size)) {
    abort(glue("All 'where' inputs must be size {size}."))
  }

  locs <- vector("list", n_values)
  unused <- rep(TRUE, times = size)

  for (i in seq_len(n_wheres)) {
    where <- wheres[[i]]

    # Only claim locations that no earlier `where` has already claimed
    loc <- unused & where
    loc <- vec_as_location(loc, n = size)
    locs[[i]] <- loc

    unused[where] <- FALSE
  }

  if (is.null(default)) {
    # Missing values of the (not yet known) output type. `unspecified()`
    # contributes nothing when the common type is computed below.
    default <- unspecified(1L)
  }

  # Append `default` to the end
  locs[[n_values]] <- vec_as_location(unused, n = size)
  values <- c(values, list(default))

  # Resolve the output type and cast every value to it *before* slicing.
  # vctrs only treats a logical `NA` as "unspecified" while it has at least one
  # element. Once it is sliced to `logical(0)`, it can no longer combine with
  # e.g. character or Date values. Without this step,
  # `case_when(x > 6, NA, default = "a")` fails whenever no `x` is > 6.
  ptype <- vec_ptype_common(!!!values, .ptype = ptype)
  values <- vec_cast_common(!!!values, .to = ptype)

  for (i in seq_len(n_values)) {
    loc <- locs[[i]]

    # Recycle up to size, then slice to insertion size
    value <- values[[i]]
    value <- vec_recycle(value, size)
    value <- vec_slice(value, loc)
    values[[i]] <- value
  }

  vec_unchop(
    x = values,
    indices = locs,
    ptype = ptype
  )
}

# Vectorised two-way choice built on `case_when()`.
#
# Uses `true` where `condition` is `TRUE` and `false` where it is `FALSE`.
# Where `condition` is `NA`, `!condition` is also `NA`. Neither branch then
# matches, so `missing` (passed as the `default`) is used there.
#
# @param condition Logical vector.
# @param true,false Values for the `TRUE` and `FALSE` locations.
# @param ... Must be empty; forces the trailing options to be named.
# @param missing Value for `NA` locations; `NULL` gives missing values.
# @param ptype,size Forwarded to `case_when()`.
if_else <- function(condition,
                    true,
                    false,
                    ...,
                    missing = NULL,
                    ptype = NULL,
                    size = NULL) {
  check_dots_empty()

  # Inputs are evaluated in the same order `case_when()` would see them
  branches <- list(condition, true, !condition, false)

  case_when(
    !!!branches,
    default = missing,
    ptype = ptype,
    size = size
  )
}

# Replace elements of `x` using `where, value` pairs, keeping `x` elsewhere.
#
# A thin wrapper around `case_when()`. `x` supplies the fallback values
# (`default`), the output type (`ptype`) and the output size (`size`), so the
# result always has the same type and size as `x`.
replace_when <- function(x, ...) {
  x_size <- vec_size(x)
  x_ptype <- vec_ptype(x)

  case_when(..., default = x, ptype = x_ptype, size = x_size)
}

# S3 generic: update the rows of `.data` selected by `.condition`.
#
# `.condition` picks the rows to revise, and `...` holds name-value pairs to
# write into those rows. Methods are dispatched on the class of `.data`.
revise <- function(.data, .condition, ...) {
  UseMethod("revise", .data)
}

# Data frame method for `revise()`.
#
# Evaluates `.condition` against `.data` to find the rows to revise. It then
# evaluates the `...` name-value pairs on just those rows and writes the
# results back into `.data` with `df_revise()`.
revise.data.frame <- function(.data, .condition, ...) {
  env <- caller_env()
  n_rows <- vec_size(.data)

  keep <- filter_rows(.data, {{ .condition }}, caller_env = env)
  rows <- vec_as_location(keep, n = n_rows)

  # Open question: slice first and then compute `...` (as done here), or
  # compute `...` on the full data and slice afterwards?
  sliced <- dplyr_row_slice(.data, rows)
  new_cols <- mutate_cols(sliced, ..., caller_env = env)

  df_revise(.data, rows, new_cols)
}

# Take a data frame and update it at `loc` with `revisions`.
# Similar in spirit to:
# x[loc, names(revisions)] <- revisions
# but built from first principles to only use `dplyr_col_modify()`.
#
# @param x A data frame.
# @param loc Integer row locations in `x` to update.
# @param revisions Named list of new values, each the size of `loc`. A name
#   already in `x` updates that column at `loc`. A new name creates a column
#   that is missing everywhere except at `loc`.
# @return `x` with the revised columns applied via `dplyr_col_modify()`.
df_revise <- function(x, loc, revisions) {
  if (any(map_lgl(revisions, is.null))) {
    abort("Can't delete columns when using `revise()`.")
  }

  # Full number of rows, needed to build brand-new columns. It must be
  # computed here: it is not visible from the caller's frame.
  size <- vec_size(x)

  col_names <- names(revisions)
  is_existing <- col_names %in% names(x)

  cols <- vector("list", length = length(revisions))
  cols <- set_names(cols, col_names)

  # Avoid any subclass `[[` methods
  old <- unclass(x)

  for (i in seq_along(cols)) {
    name <- col_names[[i]]
    revision <- revisions[[i]]

    if (is_existing[[i]]) {
      col <- old[[name]]
    } else {
      # New column: start from all-missing values of the revision's type
      col <- vec_init(revision, size)
    }

    cols[[i]] <- vec_assign(col, loc, revision, x_arg = name)
  }

  dplyr_col_modify(x, cols)
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment