Skip to content

Commit

Permalink
Merge pull request tidyverse#3981
Browse files Browse the repository at this point in the history
  • Loading branch information
lionel- committed Nov 21, 2018
2 parents d3c2365 + c399e98 commit 22d6a33
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 55 deletions.
158 changes: 114 additions & 44 deletions R/case_when.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
#' value of `n` must be consistent across all cases. The case of
#' `n == 0` is treated as a variant of `n != 1`.
#'
#' These dots support [tidy dots][rlang::tidy-dots] features.
#' `NULL` inputs are ignored.
#'
#' These dots support [tidy dots][rlang::list2] features. In
#' particular, if your patterns are stored in a list, you can
#' splice that in with `!!!`.
#' @export
#' @return A vector of length 1 or `n`, matching the length of the logical
#' input or output vectors, with the type (and attributes) of the first
Expand Down Expand Up @@ -79,21 +83,48 @@
#' type = case_when(
#' height > 200 | mass > 200 ~ "large",
#' species == "Droid" ~ "robot",
#' TRUE ~ "other"
#' TRUE ~ "other"
#' )
#' )
#'
#' # Dots support splicing:
#' patterns <- list(
#' x %% 35 == 0 ~ "fizz buzz",
#' x %% 5 == 0 ~ "fizz",
#' x %% 7 == 0 ~ "buzz",
#' TRUE ~ as.character(x)
#' )
#' case_when(!!!patterns)
#'
#' # `case_when()` is not a tidy eval function. If you'd like to reuse
#' # the same patterns, extract the `case_when()` call in a normal
#' # function:
#' case_character_type <- function(height, mass, species) {
#' case_when(
#' height > 200 | mass > 200 ~ "large",
#' species == "Droid" ~ "robot",
#' TRUE ~ "other"
#' )
#' }
#'
#' case_character_type(150, 250, "Droid")
#' case_character_type(150, 150, "Droid")
#'
#' # Such functions can be used inside `mutate()` as well:
#' starwars %>%
#' mutate(type = case_character_type(height, mass, species)) %>%
#' pull(type)
#'
#' # `case_when()` ignores `NULL` inputs. This is useful when you'd
#' # like to use a pattern only under certain conditions. Here we'll
#' # take advantage of the fact that `if` returns `NULL` when there is
#' # no `else` clause:
#' case_character_type <- function(height, mass, species, robots = TRUE) {
#' case_when(
#' height > 200 | mass > 200 ~ "large",
#' if (robots) species == "Droid" ~ "robot",
#' TRUE ~ "other"
#' )
#' }
#'
#' starwars %>%
#' mutate(type = case_character_type(height, mass, species, robots = FALSE)) %>%
#' pull(type)
case_when <- function(...) {
formulas <- list2(...)
n <- length(formulas)
fs <- compact_null(list2(...))
n <- length(fs)

if (n == 0) {
abort("No cases provided")
Expand All @@ -102,44 +133,20 @@ case_when <- function(...) {
query <- vector("list", n)
value <- vector("list", n)

for (i in seq_len(n)) {
f <- formulas[[i]]
if (!inherits(f, "formula") || length(f) != 3) {
non_formula_arg <- substitute(list(...))[[i + 1]]
header <- glue("Case {i} ({deparsed})", deparsed = fmt_obj1(deparse_trunc(non_formula_arg)))
glubort(header, "must be a two-sided formula, not {friendly_type_of(f)}")
}
default_env <- caller_env()
quos_pairs <- map2(fs, seq_along(fs), validate_formula, default_env, current_env())

env <- environment(f)
for (i in seq_len(n)) {
pair <- quos_pairs[[i]]
query[[i]] <- eval_tidy(pair$lhs, env = default_env)
value[[i]] <- eval_tidy(pair$rhs, env = default_env)

query[[i]] <- eval_bare(f[[2]], env)
if (!is.logical(query[[i]])) {
header <- glue("LHS of case {i} ({deparsed})", deparsed = fmt_obj1(deparse_trunc(f_lhs(f))))
glubort(header, "must be a logical, not {friendly_type_of(query[[i]])}")
abort_case_when_logical(pair$lhs, i, query[[i]])
}

value[[i]] <- eval_bare(f[[3]], env)
}

lhs_lengths <- map_int(query, length)
rhs_lengths <- map_int(value, length)
all_lengths <- unique(c(lhs_lengths, rhs_lengths))
if (length(all_lengths) <= 1) {
m <- all_lengths[[1]]
} else {
non_atomic_lengths <- all_lengths[all_lengths != 1]
m <- non_atomic_lengths[[1]]
if (length(non_atomic_lengths) > 1) {
inconsistent_lengths <- non_atomic_lengths[-1]
lhs_problems <- lhs_lengths %in% inconsistent_lengths
rhs_problems <- rhs_lengths %in% inconsistent_lengths
problems <- lhs_problems | rhs_problems
bad_calls(
formulas[problems],
check_length_val(inconsistent_lengths, m, header = NULL, .abort = identity)
)
}
}
m <- validate_case_when_length(query, value, fs)

out <- value[[1]][rep(NA_integer_, m)]
replaced <- rep(FALSE, m)
Expand All @@ -151,3 +158,66 @@ case_when <- function(...) {

out
}

validate_formula <- function(x, i, default_env, dots_env) {
# Formula might be quosured
if (is_quosure(x)) {
default_env <- quo_get_env(x)
x <- quo_get_expr(x)
}

if (!is_formula(x)) {
arg <- substitute(...(), dots_env)[[1]]
abort_case_when_formula(arg, i, x)
}
if (is_null(f_lhs(x))) {
abort("formulas must be two-sided")
}

# Formula might be unevaluated, e.g. if it's been quosured
env <- f_env(x) %||% default_env

list(
lhs = new_quosure(f_lhs(x), env),
rhs = new_quosure(f_rhs(x), env)
)
}

abort_case_when_formula <- function(arg, i, obj) {
deparsed <- fmt_obj1(deparse_trunc(arg))
type <- friendly_type_of(obj)
abort(glue("Case {i} ({deparsed}) must be a two-sided formula, not {type}"))
}

abort_case_when_logical <- function(lhs, i, query) {
deparsed <- fmt_obj1(deparse_trunc(quo_squash(lhs)))
type <- friendly_type_of(query)
abort(glue("LHS of case {i} ({deparsed}) must be a logical vector, not {type}"))
}

validate_case_when_length <- function(query, value, fs) {
lhs_lengths <- map_int(query, length)
rhs_lengths <- map_int(value, length)
all_lengths <- unique(c(lhs_lengths, rhs_lengths))

if (length(all_lengths) <= 1) {
return(all_lengths[[1]])
}

non_atomic_lengths <- all_lengths[all_lengths != 1]
len <- non_atomic_lengths[[1]]

if (length(non_atomic_lengths) == 1) {
return(len)
}

inconsistent_lengths <- non_atomic_lengths[-1]
lhs_problems <- lhs_lengths %in% inconsistent_lengths
rhs_problems <- rhs_lengths %in% inconsistent_lengths
problems <- lhs_problems | rhs_problems

bad_calls(
fs[problems],
check_length_val(inconsistent_lengths, len, header = NULL, .abort = identity)
)
}
4 changes: 4 additions & 0 deletions R/utils.r
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,7 @@ unstructure <- function(x) {
attributes(x) <- NULL
x
}

compact_null <- function(x) {
Filter(function(elt) !is.null(elt), x)
}
51 changes: 41 additions & 10 deletions man/case_when.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

94 changes: 93 additions & 1 deletion tests/testthat/test-case-when.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ test_that("error messages", {
case_when(
50 ~ 1:3
),
"LHS of case 1 (`50`) must be a logical, not a double vector",
"LHS of case 1 (`50`) must be a logical vector, not a double vector",
fixed = TRUE
)
})
Expand Down Expand Up @@ -132,3 +132,95 @@ test_that("case_when can be used in anonymous functions (#3422)", {
pull()
expect_equal(res, c(TRUE, FALSE, FALSE))
})

test_that("case_when() can be used inside mutate()", {
out <- mtcars[1:4, ] %>%
mutate(out = case_when(
cyl == 4 ~ 1,
.data[["am"]] == 1 ~ 2,
TRUE ~ 0
)) %>%
pull()
expect_identical(out, c(2, 2, 1, 0))
})

test_that("can pass quosures to case_when()", {
fs <- local({
x <- 3:1
quos(
x < 2 ~ TRUE,
TRUE ~ FALSE
)
})
expect_identical(case_when(!!!fs), c(FALSE, FALSE, TRUE))
})

test_that("can pass nested quosures to case_when()", {
fs <- local({
foo <- mtcars$cyl[1:4]
quos(
!!quo(foo) == 4 ~ 1,
TRUE ~ 0
)
})
expect_identical(case_when(!!!fs), c(0, 0, 1, 0))
})

test_that("can pass unevaluated formulas to case_when()", {
x <- 6:8
fs <- exprs(
x == 7L ~ TRUE,
TRUE ~ FALSE
)
expect_identical(case_when(!!!fs), c(FALSE, TRUE, FALSE))

out <- local({
x <- 7:9
case_when(!!!fs)
})
expect_identical(out, c(TRUE, FALSE, FALSE))
})

test_that("unevaluated formulas can refer to data mask", {
fs <- exprs(
cyl == 4 ~ 1,
am == 1 ~ 2,
TRUE ~ 0
)
out <- mtcars[1:4, ] %>% mutate(out = case_when(!!!fs)) %>% pull()
expect_identical(out, c(2, 2, 1, 0))
})

test_that("unevaluated formulas can contain quosures", {
quo <- local({
n <- 4
quo(n)
})
fs <- exprs(
cyl == !!quo ~ 1,
am == 1 ~ 2,
TRUE ~ 0
)
out <- mtcars[1:4, ] %>% mutate(out = case_when(!!!fs)) %>% pull()
expect_identical(out, c(2, 2, 1, 0))
})

test_that("NULL inputs are compacted", {
x <- 1:3

bool <- FALSE
out <- case_when(
x == 2 ~ TRUE,
if (bool) x == 3 ~ NA,
TRUE ~ FALSE
)
expect_identical(out, c(FALSE, TRUE, FALSE))

bool <- TRUE
out <- case_when(
x == 2 ~ TRUE,
if (bool) x == 3 ~ NA,
TRUE ~ FALSE
)
expect_identical(out, c(FALSE, TRUE, NA))
})

0 comments on commit 22d6a33

Please sign in to comment.