Skip to content

Commit

Permalink
Clarify translate_window_where interface
Browse files Browse the repository at this point in the history
  • Loading branch information
hadley committed Mar 17, 2016
1 parent 7ad329e commit 0e6005a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
2 changes: 1 addition & 1 deletion R/sql-build.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ sql_build.op_filter <- function(op, con, ...) {
} else {
# Do partial evaluation, then extract out window functions
expr <- partial_eval2(op$dots, vars)
where <- translate_window_where(expr, ls(sql_translate_env(con)$window))
where <- translate_window_where_all(expr, ls(sql_translate_env(con)$window))

# Convert where$expr back to a lazy dots object, and then
# create mutate operation
Expand Down
64 changes: 36 additions & 28 deletions R/translate-sql-window.r
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ uses_window_fun <- function(x, con) {
any(calls %in% win_f)
}

common_window_funs <- ls(sql_translate_env(NULL)$window)

#' @noRd
#' @examples
#' translate_window_where(quote(1))
Expand All @@ -18,41 +20,47 @@ uses_window_fun <- function(x, con) {
#' translate_window_where(quote(x == 1 && y == 2))
#' translate_window_where(quote(n() > 10))
#' translate_window_where(quote(rank() > cumsum(AB)))
#' translate_window_where(list(quote(x == 1), quote(n() > 2)))
#' translate_window_where(list(quote(cumsum(x) == 10), quote(n() > 2)))
translate_window_where <- function(expr, window_funs = ls(sql_translate_env(NULL)$window)) {
# Simplest base case: atomic vector or name ---------------------------------
translate_window_where <- function(expr, window_funs = common_window_funs) {
if (is.atomic(expr) || is.name(expr)) {
return(list(
expr = expr,
comp = list()
))
}
window_where(expr, list())
} else if (is.call(expr)) {
if (as.character(expr[[1]]) %in% window_funs) {
name <- unique_name()
window_where(as.name(name), setNames(list(expr), name))
} else {
args <- lapply(expr[-1], translate_window_where, window_funs = window_funs)
expr <- as.call(c(expr[[1]], lapply(args, "[[", "expr")))

# Other base case is an aggregation function --------------------------------
if (is.call(expr) && as.character(expr[[1]]) %in% window_funs) {
name <- unique_name()

return(list(
expr = as.name(name),
comp = setNames(list(expr), name)
))
window_where(
expr = expr,
comp = unlist(lapply(args, "[[", "comp"), recursive = FALSE)
)
}
} else {
stop("Unknown type: ", typeof(expr))
}
}

# Recursive cases: list and all other functions -----------------------------

if (is.list(expr)) {
args <- lapply(expr, translate_window_where, window_funs = window_funs)
call <- unlist(lapply(args, "[[", "expr"), recursive = FALSE)
} else {
args <- lapply(expr[-1], translate_window_where, window_funs = window_funs)
call <- list(as.call(c(expr[[1]], lapply(args, "[[", "expr"))))
}
#' @noRd
#' @examples
#' translate_window_where_all(list(quote(x == 1), quote(n() > 2)))
#' translate_window_where_all(list(quote(cumsum(x) == 10), quote(n() > 2)))
translate_window_where_all <- function(x, window_funs = common_window_funs) {
out <- lapply(x, translate_window_where, window_funs = window_funs)

list(
expr = unlist(lapply(out, "[[", "expr"), recursive = FALSE),
comp = unlist(lapply(out, "[[", "comp"), recursive = FALSE)
)
}

comps <- unlist(lapply(args, "[[", "comp"), recursive = FALSE)
window_where <- function(expr, comp) {
stopifnot(is.call(expr) || is.name(expr) || is.atomic(expr))
stopifnot(is.list(comp))

list(
expr = call,
comp = comps
expr = expr,
comp = comp
)
}

0 comments on commit 0e6005a

Please sign in to comment.