Skip to content

Commit

Permalink
support window escaping and window tests (tidyverse#1861)
Browse files Browse the repository at this point in the history
  • Loading branch information
javierluraschi authored and hadley committed May 27, 2016
1 parent 4b27fda commit d2a34e7
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 2 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ Collate:
'src-postgres.r'
'src-sql.r'
'src-sqlite.r'
'src-test.r'
'src.r'
'tally.R'
'tbl-cube.r'
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ S3method(db_insert_into,PostgreSQLConnection)
S3method(db_insert_into,SQLiteConnection)
S3method(db_list_tables,DBIConnection)
S3method(db_query_fields,DBIConnection)
S3method(db_query_fields,DBITestConnection)
S3method(db_query_fields,PostgreSQLConnection)
S3method(db_query_rows,DBIConnection)
S3method(db_rollback,DBIConnection)
Expand Down Expand Up @@ -224,6 +225,7 @@ S3method(sql_build,op_summarise)
S3method(sql_build,op_ungroup)
S3method(sql_build,tbl_lazy)
S3method(sql_build,tbl_sql)
S3method(sql_escape_ident,DBITestConnection)
S3method(sql_escape_ident,MySQLConnection)
S3method(sql_escape_ident,SQLiteConnection)
S3method(sql_escape_ident,default)
Expand All @@ -244,6 +246,7 @@ S3method(sql_set_op,default)
S3method(sql_subquery,SQLiteConnection)
S3method(sql_subquery,default)
S3method(sql_translate_env,"NULL")
S3method(sql_translate_env,DBITestConnection)
S3method(sql_translate_env,MySQLConnection)
S3method(sql_translate_env,PostgreSQLConnection)
S3method(sql_translate_env,SQLiteConnection)
Expand Down
2 changes: 1 addition & 1 deletion R/over.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ over <- function(expr, partition = NULL, order = NULL, frame = NULL) {
}

partition <- build_sql("PARTITION BY ",
sql_vector(escape(partition), collapse = ", ", parens = FALSE))
sql_vector(escape(partition, con = partition_con()), collapse = ", ", parens = FALSE))
}
if (!is.null(order)) {
if (!is.sql(order)) {
Expand Down
29 changes: 29 additions & 0 deletions R/src-test.r
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#' A set of DBI methods to ease unit testing dplyr with DBI
#' @name src-test
#' @param con A database connection.
#' @param x Object to transform
#' @param sql A string containing an sql query.
#' @param ... Other arguments passed on to the individual methods
NULL

#' @export
#' @rdname src-test
db_query_fields.DBITestConnection <- function(con, sql, ...) {
c("field1")
}

#' @export
#' @rdname src-test
sql_escape_ident.DBITestConnection <- function(con, x) {
sql_quote(x, '`')
}

#' @export
#' @rdname src-test
sql_translate_env.DBITestConnection <- function(con) {
dplyr::sql_variant(
scalar = dplyr::sql_translator(.parent = dplyr::base_scalar),
aggregate = dplyr::sql_translator(.parent = dplyr::base_agg),
window = dplyr::sql_translator(.parent = dplyr::base_win)
)
}
11 changes: 10 additions & 1 deletion R/translate-sql-helpers.r
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ win_absent <- function(f) {
partition <- new.env(parent = emptyenv())
partition$group_by <- NULL
partition$order_by <- NULL
partition$con <- NULL

set_partition_con <- function(con) {
old <- partition$con
partition$con <- con
invisible(old)
}

set_partition_group <- function(vars) {
stopifnot(is.null(vars) || is.character(vars))
Expand All @@ -194,7 +201,7 @@ set_partition_order <- function(vars) {
invisible(old)
}

set_partition <- function(group_by, order_by) {
set_partition <- function(group_by, order_by, con = NULL) {
old <- list(partition$group_by, partition$order_by)
if (is.list(group_by)) {
order_by <- group_by[[2]]
Expand All @@ -203,9 +210,11 @@ set_partition <- function(group_by, order_by) {

partition$group_by <- group_by
partition$order_by <- order_by
partition$con <- con

invisible(old)
}

partition_group <- function() partition$group_by
partition_order <- function() partition$order_by
partition_con <- function() partition$con
3 changes: 3 additions & 0 deletions R/translate-sql.r
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ translate_sql_ <- function(dots,
variant <- sql_translate_env(con)

if (window) {
old_con <- set_partition_con(con)
on.exit(set_partition_con(old_con), add = TRUE)

old_group <- set_partition_group(vars_group)
on.exit(set_partition_group(old_group), add = TRUE)

Expand Down
28 changes: 28 additions & 0 deletions man/src-test.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions tests/testthat/test-sql-translation.r
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,11 @@ test_that("ntile always casts to integer", {
)
})

test_that("connection affects quoting character", {
dbiTest <- structure(list(), class = "DBITestConnection")
dbTest <- src_sql("test", con = dbiTest)
testTable <- tbl_sql("test", src = dbTest, from = "table1")

out <- select(testTable, field1)
expect_match(sql_render(out), "^SELECT `field1` AS `field1`\nFROM `table1`$")
})
21 changes: 21 additions & 0 deletions tests/testthat/test-window.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,24 @@ test_that("multiple group by or order values don't have parens", {
sql('"x" OVER (PARTITION BY "x", "y")')
)
})

test_that("connection affects quoting window function fields", {
dbiTest <- structure(list(), class = "DBITestConnection")
dbTest <- src_sql("test", con = dbiTest)
testTable <- tbl_sql("test", src = dbTest, from = "table1")

out <- filter(group_by(testTable, field1), min_rank(desc(field1)) < 2)
sqlText <- sql_render(out)

testthat::expect_equal(
grep(paste(
"^SELECT `field1`",
"FROM \\(SELECT `field1`, rank\\(\\) OVER \\(PARTITION BY `field1` ORDER BY `field1` DESC\\) AS `[a-zA-Z0-9]+`",
"FROM `table1`\\) `[a-zA-Z0-9]+`",
"WHERE \\(`[a-zA-Z0-9]+` < 2.0\\)$",
sep = "\n"
), sqlText),
1,
info = sqlText
)
})

0 comments on commit d2a34e7

Please sign in to comment.