Skip to content

Commit

Permalink
Rework translate_sql interface
Browse files Browse the repository at this point in the history
* Use lazydots
* Instead of tbl, pass con & vars
* Default to window = TRUE
* src_translate_env -> sql_translate_env
  • Loading branch information
hadley committed Mar 16, 2016
1 parent 8cc66df commit 708187c
Show file tree
Hide file tree
Showing 17 changed files with 262 additions and 213 deletions.
13 changes: 6 additions & 7 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,15 @@ S3method(sql_semi_join,DBIConnection)
S3method(sql_set_op,DBIConnection)
S3method(sql_subquery,"NULL")
S3method(sql_subquery,DBIConnection)
S3method(sql_translate_env,"NULL")
S3method(sql_translate_env,MySQLConnection)
S3method(sql_translate_env,PostgreSQLConnection)
S3method(sql_translate_env,SQLiteConnection)
S3method(src_desc,src_mysql)
S3method(src_desc,src_postgres)
S3method(src_desc,src_sqlite)
S3method(src_tbls,src_local)
S3method(src_tbls,src_sql)
S3method(src_translate_env,"NULL")
S3method(src_translate_env,src_mysql)
S3method(src_translate_env,src_postgres)
S3method(src_translate_env,src_sqlite)
S3method(src_translate_env,tbl_sql)
S3method(summarise_,data.frame)
S3method(summarise_,tbl_cube)
S3method(summarise_,tbl_df)
Expand Down Expand Up @@ -471,6 +470,7 @@ export(sql_select)
export(sql_semi_join)
export(sql_set_op)
export(sql_subquery)
export(sql_translate_env)
export(sql_translator)
export(sql_variant)
export(src)
Expand All @@ -483,7 +483,6 @@ export(src_postgres)
export(src_sql)
export(src_sqlite)
export(src_tbls)
export(src_translate_env)
export(starts_with)
export(summarise)
export(summarise_)
Expand All @@ -506,7 +505,7 @@ export(test_register_src)
export(tibble)
export(top_n)
export(translate_sql)
export(translate_sql_q)
export(translate_sql_)
export(transmute)
export(transmute_)
export(trunc_mat)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# dplyr 0.4.3.9000

* `src_translate_env()` has been replaced by `sql_translate_env()`, which
should have methods for the connection object.

* New `src_memdb()` which is a session-local in-memory SQLite db.

* `distinct()` now only keeps the distinct variables. If you want to return
Expand Down
14 changes: 13 additions & 1 deletion R/dbi-s3.r
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,17 @@ src_desc <- function(x) UseMethod("src_desc")

#' @name backend_src
#' @export
src_translate_env <- function(x) UseMethod("src_translate_env")
sql_translate_env <- function(con) UseMethod("sql_translate_env")

#' @name backend_src
#' @export
sql_translate_env.NULL <- function(con) {
  # Method for a NULL `con`: the variant is assembled from base_scalar,
  # base_agg and base_win, handed to sql_variant() positionally in that order.
  sql_variant(base_scalar, base_agg, base_win)
}

#' Database generics.
#'
Expand Down Expand Up @@ -222,6 +231,9 @@ sql_select.DBIConnection <- function(con, select, from, where = NULL,
names(out) <- c("select", "from", "where", "group_by", "having", "order_by",
"limit", "offset")

if (length(select) == 0) {
select <- "*"
}
assert_that(is.character(select), length(select) > 0L)
out$select <- build_sql("SELECT ", escape(select, collapse = ", ", con = con))

Expand Down
37 changes: 37 additions & 0 deletions R/partial-eval.r
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,40 @@ partial_eval <- function(call, tbl = NULL, env = parent.frame()) {
}
}


# Partially evaluate an expression.
#
# Names listed in `vars` are left untouched as symbols; any other name that
# can be found from `env` is replaced by its value. Calls are processed
# argument by argument, with three special cases: `local(x)` evaluates `x`
# in `env`, `remote(x)` returns `x` unevaluated, and subsetting calls
# (`$`, `[[`, `[`) are evaluated in `env` as a whole.
#
# Arguments:
#   call  An atomic value, NULL, symbol, call, list of expressions, or a
#         "lazy_dots" object (each element supplies its own `expr` and `env`).
#   vars  Character vector of names that must be left unevaluated.
#   env   Environment used to look up and evaluate everything else.
#
# Returns the partially evaluated input; a list or lazy_dots input yields a
# list with one result per element. Stops with "Unknown input type" for any
# other kind of object.
partial_eval2 <- function(call, vars = character(), env = parent.frame()) {
  # NULL is checked explicitly: is.atomic(NULL) is FALSE from R 4.4.0 on,
  # which would otherwise send NULL to the "Unknown input type" error.
  if (is.null(call) || is.atomic(call)) return(call)

  if (inherits(call, "lazy_dots")) {
    lapply(call, function(l) partial_eval2(l$expr, vars, l$env))
  } else if (is.list(call)) {
    lapply(call, partial_eval2, vars, env = env)
  } else if (is.symbol(call)) {
    name <- as.character(call)
    if (name %in% vars) {
      call
    } else if (exists(name, env)) {
      eval(call, env)
    } else {
      call
    }
  } else if (is.call(call)) {
    # Process call arguments recursively, unless user has manually called
    # remote/local
    head <- call[[1]]
    # The head of a call is not always a symbol: for `pkg::fun(x)` or
    # `obj$fun(x)` it is itself a call and as.character() returns several
    # strings. Keep only the first one (the head's own function name) so
    # that `name` is a scalar and safe to use in `if`. Any other kind of
    # head (e.g. an inlined function object) gets no name and falls through
    # to the default branch.
    name <- if (is.symbol(head) || is.call(head)) {
      as.character(head)[[1]]
    } else {
      ""
    }

    if (name == "local") {
      eval(call[[2]], env)
    } else if (name %in% c("$", "[[", "[")) {
      # Subsetting is always done locally
      eval(call, env)
    } else if (name == "remote") {
      call[[2]]
    } else {
      call[-1] <- lapply(call[-1], partial_eval2, vars = vars, env = env)
      call
    }
  } else {
    stop("Unknown input type: ", class(call), call. = FALSE)
  }
}

4 changes: 2 additions & 2 deletions R/sql-build.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ sql_build.op_rename <- function(op, ...) {
select_query(sql_build(op$x, con), vars)
}


# select_query ------------------------------------------------------------

#' @export
#' @rdname sql_build
select_query <- function(from, select,
select_query <- function(from,
select = character(),
where = character(),
group_by = character(),
having = character(),
Expand Down
2 changes: 1 addition & 1 deletion R/src-mysql.r
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ src_desc.src_mysql <- function(x) {
}

#' @export
src_translate_env.src_mysql <- function(x) {
sql_translate_env.MySQLConnection <- function(x) {
sql_variant(
base_scalar,
sql_translator(.parent = base_agg,
Expand Down
2 changes: 1 addition & 1 deletion R/src-postgres.r
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ src_desc.src_postgres <- function(x) {
}

#' @export
src_translate_env.src_postgres <- function(x) {
sql_translate_env.PostgreSQLConnection <- function(x) {
sql_variant(
base_scalar,
sql_translator(.parent = base_agg,
Expand Down
2 changes: 1 addition & 1 deletion R/src-sqlite.r
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ src_desc.src_sqlite <- function(x) {
}

#' @export
src_translate_env.src_sqlite <- function(x) {
sql_translate_env.SQLiteConnection <- function(x) {
sql_variant(
sql_translator(.parent = base_scalar,
log = sql_prefix("log")
Expand Down
16 changes: 8 additions & 8 deletions R/tbl-sql.r
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ setdiff.tbl_sql <- function(x, y, copy = FALSE, ...) {
build_query <- function(x, limit = NULL) {
assert_that(is.null(limit) || (is.numeric(limit) && length(limit) == 1))
translate <- function(expr, ...) {
translate_sql_q(expr, tbl = x, env = NULL, ...)
translate_sql(expr, vars = names(x), ...)
}

if (x$summarise) {
Expand All @@ -217,7 +217,7 @@ build_query <- function(x, limit = NULL) {
order_by_sql <- translate(x$order_by)
} else {
# Not in summarise, so assume functions are window functions
select_sql <- translate(x$select, window = uses_window_fun(x$select, x))
select_sql <- translate(x$select, window = uses_window_fun(x$select, x$con))
vars <- auto_names(x$select)

# Don't use group_by - grouping affects window functions only
Expand All @@ -232,7 +232,7 @@ build_query <- function(x, limit = NULL) {
}
}

if (!uses_window_fun(x$where, x)) {
if (!uses_window_fun(x$where, x$con)) {
from_sql <- x$from
where_sql <- translate(x$where)
} else {
Expand All @@ -254,15 +254,15 @@ build_query <- function(x, limit = NULL) {
query(x$src$con, sql, vars)
}

uses_window_fun <- function(x, tbl) {
uses_window_fun <- function(x, con) {
if (is.null(x)) return(FALSE)
if (is.list(x)) {
calls <- unlist(lapply(x, all_calls))
} else {
calls <- all_calls(x)
}

win_f <- ls(envir = src_translate_env(tbl)$window)
win_f <- ls(envir = sql_translate_env(con)$window)
any(calls %in% win_f)
}

Expand Down Expand Up @@ -350,7 +350,7 @@ mutate_.tbl_sql <- function(.data, ..., .dots) {
# If we're creating a variable that uses a window function, it's
# safest to turn that into a subquery so that filter etc can use
# the new variable name
if (uses_window_fun(input, .data)) {
if (uses_window_fun(input, .data$con)) {
collapse(new)
} else {
new
Expand All @@ -367,8 +367,8 @@ group_by_.tbl_sql <- function(.data, ..., .dots, add = FALSE) {
# * filter: changes frame of window functions
# * mutate: changes frame of window functions
# * arrange: if present, groups inserted as first ordering
needed <- (x$mutate && uses_window_fun(x$select, x)) ||
uses_window_fun(x$filter, x)
needed <- (x$mutate && uses_window_fun(x$select, x$con)) ||
uses_window_fun(x$filter, x$con)
if (!is.null(x$order_by)) {
arrange <- c(x$group_by, x$order_by)
} else {
Expand Down
18 changes: 18 additions & 0 deletions R/translate-sql-helpers.r
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,24 @@ win_absent <- function(f) {
# tbl and sql translator. This isn't the most amazing design, but it keeps
# things loosely coupled and is straightforward to understand.
partition <- new.env(parent = emptyenv())
# Both slots start out as NULL; the set_partition*() helpers below overwrite
# them and return the previous values.
partition$group_by <- NULL
partition$order_by <- NULL

set_partition_group <- function(vars) {
  # Replace the `group_by` slot of `partition` with `vars` and invisibly
  # return whatever was stored there before, so callers can restore it.
  # Only NULL or a character vector is accepted.
  stopifnot(is.null(vars) || is.character(vars))

  previous <- partition[["group_by"]]
  assign("group_by", vars, envir = partition)
  invisible(previous)
}

set_partition_order <- function(vars) {
  # Replace the `order_by` slot of `partition` with `vars` and invisibly
  # return the value it held before. Only NULL or a character vector is
  # accepted.
  stopifnot(is.null(vars) || is.character(vars))

  previous <- partition[["order_by"]]
  assign("order_by", vars, envir = partition)
  invisible(previous)
}

set_partition <- function(group_by, order_by) {
old <- list(partition$group_by, partition$order_by)
Expand Down
10 changes: 5 additions & 5 deletions R/translate-sql-window.r
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# translate_window_where(quote(n() > 10), players)
# translate_window_where(quote(rank() > cumsum(AB)), players)
# translate_window_where(list(quote(x == 1), quote(n() > 2)), players)
translate_window_where <- function(expr, tbl, con = NULL) {
translate_window_where <- function(expr, vars, con = NULL) {
# Simplest base case: atomic vector or name ---------------------------------
if (is.atomic(expr) || is.name(expr)) {
return(list(
Expand All @@ -17,12 +17,12 @@ translate_window_where <- function(expr, tbl, con = NULL) {
}

# Other base case is an aggregation function --------------------------------
variant <- src_translate_env(tbl)
variant <- sql_translate_env(con)
agg_f <- ls(envir = variant$window)

if (is.call(expr) && as.character(expr[[1]]) %in% agg_f) {
name <- unique_name()
sql <- translate_sql_q(list(expr), tbl, env = NULL, window = TRUE)
sql <- translate_sql_(list(expr), vars = vars)

return(list(
expr = as.name(name),
Expand All @@ -33,12 +33,12 @@ translate_window_where <- function(expr, tbl, con = NULL) {
# Recursive cases: list and all other functions -----------------------------

if (is.list(expr)) {
args <- lapply(expr, translate_window_where, tbl = tbl, con = con)
args <- lapply(expr, translate_window_where, vars = vars, con = con)

env <- sql_env(call, variant, con = con)
sql <- lapply(lapply(args, "[[", "expr"), eval, env = env)
} else {
args <- lapply(expr[-1], translate_window_where, tbl = tbl, con = con)
args <- lapply(expr[-1], translate_window_where, vars = vars, con = con)

call <- as.call(c(expr[[1]], lapply(args, "[[", "expr")))
env <- sql_env(call, variant, con = con)
Expand Down
Loading

0 comments on commit 708187c

Please sign in to comment.