Skip to content

Commit

Permalink
ARROW-17463: [R] Avoid unnecessary projections (apache#13954)
Browse files Browse the repository at this point in the history
Before:

```
> mtcars |> arrow_table() |> count(cyl) |> explain()
ExecPlan with 6 nodes:
5:SinkNode{}
  4:ProjectNode{projection=[cyl, n]}
    3:ProjectNode{projection=[cyl, n]}
      2:GroupByNode{keys=["cyl"], aggregates=[
      	hash_sum(n, {skip_nulls=true, min_count=1}),
      ]}
        1:ProjectNode{projection=["n": 1, cyl]}
          0:TableSourceNode{}
```

After:

```
ExecPlan with 5 nodes:
4:SinkNode{}
  3:ProjectNode{projection=[cyl, n]}
    2:GroupByNode{keys=["cyl"], aggregates=[
    	hash_sum(n, {skip_nulls=true, min_count=1}),
    ]}
      1:ProjectNode{projection=["n": 1, cyl]}
        0:TableSourceNode{}
```

Authored-by: Neal Richardson <[email protected]>
Signed-off-by: Neal Richardson <[email protected]>
  • Loading branch information
nealrichardson authored Aug 27, 2022
1 parent 1b9c57e commit 80bba29
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 36 deletions.
24 changes: 20 additions & 4 deletions r/R/query-engine.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,14 @@ ExecPlan <- R6Class("ExecPlan",
}
} else {
# If any columns are derived, reordered, or renamed we need to Project
# If there are aggregations, the projection was already handled above
# If there are aggregations, the projection was already handled above.
# We have to project at least once to eliminate some junk columns
# that the ExecPlan adds:
# __fragment_index, __batch_index, __last_in_fragment
# Presumably extraneous repeated projection of the same thing
# (as when we've done collapse() and not projected after) is cheap/no-op
#
# $Project() will check whether we actually need to project, so that
# repeated projection of the same thing
# (as when we've done collapse() and not projected after) is avoided
projection <- c(.data$selected_columns, .data$temp_columns)
node <- node$Project(projection)
if (!is.null(.data$join)) {
Expand Down Expand Up @@ -349,7 +351,11 @@ ExecNode <- R6Class("ExecNode",
Project = function(cols) {
if (length(cols)) {
assert_is_list_of(cols, "Expression")
self$preserve_extras(ExecNode_Project(self, cols, names(cols)))
if (needs_projection(cols, self$schema)) {
self$preserve_extras(ExecNode_Project(self, cols, names(cols)))
} else {
self
}
} else {
self$preserve_extras(ExecNode_Project(self, character(0), character(0)))
}
Expand Down Expand Up @@ -402,3 +408,13 @@ do_exec_plan_substrait <- function(substrait_plan) {
plan <- ExecPlan$create()
ExecPlan_run_substrait(plan, substrait_plan)
}

# Decide whether applying `projection` to data with the given `schema` would
# change anything, so that a no-op projection can be skipped.
#
# `projection`: a named list whose elements expose `$field_name`; an empty
#   `field_name` is taken to mean the element is not a plain field reference.
# `schema`: an object whose `names()` are the current column names, in order.
#
# Returns a single logical: TRUE when the projection has to be applied.
needs_projection <- function(projection, schema) {
  # Field name behind each projected expression. The list's own names are
  # dropped so the identical() comparisons below look at values only.
  referenced <- vapply(
    projection,
    function(expr) expr$field_name,
    character(1),
    USE.NAMES = FALSE
  )

  # The projection is a no-op only when all of the following hold:
  #  * every Expression is a bare FieldRef (non-empty field name)
  #  * no field is renamed (output names equal the referenced fields)
  #  * the fields are exactly the schema's fields, in the schema's order
  #    (so dropping or reordering columns also requires a projection)
  is_noop <- all(nzchar(referenced)) &&
    identical(referenced, names(projection)) &&
    identical(referenced, names(schema))

  !is_noop
}
36 changes: 36 additions & 0 deletions r/tests/testthat/test-dplyr-collapse.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,39 @@ test_that("query_on_dataset handles collapse()", {
select(int)
))
})

test_that("collapse doesn't unnecessarily add ProjectNodes", {
  # Indices of the lines in the printed query plan that mention a ProjectNode
  project_node_lines <- function(query) {
    grep("ProjectNode", capture.output(show_query(query)))
  }

  # Collapsing an untouched table repeatedly should add no projections
  expect_length(project_node_lines(tab %>% collapse() %>% collapse()), 0)

  # After a select() there should be just one projection, no matter how many
  # times the query is collapsed afterwards
  expect_length(
    project_node_lines(tab %>% select(int, chr) %>% collapse() %>% collapse()),
    1
  )

  # We need one ProjectNode on dataset queries to handle augmented fields
  skip_if_not_available("dataset")

  dataset_dir <- tempfile()
  write_dataset(tab, dataset_dir, partitioning = "lgl")
  ds <- open_dataset(dataset_dir)

  expect_length(project_node_lines(ds %>% collapse() %>% collapse()), 1)
})
82 changes: 52 additions & 30 deletions r/tests/testthat/test-dplyr-query.R
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,9 @@ test_that("show_exec_plan(), show_query() and explain()", {
arrow_table() %>%
show_exec_plan(),
regexp = paste0(
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"ProjectNode.*", # output columns
"TableSourceNode" # entry point
"ExecPlan with 2 nodes:.*", # boiler plate for ExecPlan
"SinkNode.*", # output
"TableSourceNode" # entry point
)
)

Expand All @@ -463,12 +463,12 @@ test_that("show_exec_plan(), show_query() and explain()", {
mutate(int_plus_ten = int + 10) %>%
show_exec_plan(),
regexp = paste0(
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"chr, int, lgl, \"int_plus_ten\".*", # selected columns
"FilterNode.*", # filter node
"(dbl > 2).*", # filter expressions
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"chr, int, lgl, \"int_plus_ten\".*", # selected columns
"FilterNode.*", # filter node
"(dbl > 2).*", # filter expressions
"chr != \"e\".*",
"TableSourceNode" # entry point
"TableSourceNode" # entry point
)
)

Expand All @@ -481,11 +481,11 @@ test_that("show_exec_plan(), show_query() and explain()", {
mutate(int_plus_ten = int + 10) %>%
show_exec_plan(),
regexp = paste0(
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"chr, int, lgl, \"int_plus_ten\".*", # selected columns
"(dbl > 2).*", # the filter expressions
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"chr, int, lgl, \"int_plus_ten\".*", # selected columns
"(dbl > 2).*", # the filter expressions
"chr != \"e\".*",
"TableSourceNode" # the entry point
"TableSourceNode" # the entry point
)
)

Expand All @@ -497,13 +497,13 @@ test_that("show_exec_plan(), show_query() and explain()", {
summarise(avg = mean(dbl, na.rm = TRUE)) %>%
show_exec_plan(),
regexp = paste0(
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"ProjectNode.*", # output columns
"GroupByNode.*", # the group_by statement
"keys=.*lgl.*", # the key for the aggregations
"aggregates=.*hash_mean.*avg.*", # the aggregations
"ProjectNode.*", # the input columns
"TableSourceNode" # the entry point
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"ProjectNode.*", # output columns
"GroupByNode.*", # the group_by statement
"keys=.*lgl.*", # the key for the aggregations
"aggregates=.*hash_mean.*avg.*", # the aggregations
"ProjectNode.*", # the input columns
"TableSourceNode" # the entry point
)
)

Expand All @@ -521,14 +521,13 @@ test_that("show_exec_plan(), show_query() and explain()", {
select(int, verses, doubled_dbl) %>%
show_exec_plan(),
regexp = paste0(
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"ProjectNode.*", # output columns
"HashJoinNode.*", # the join
"ProjectNode.*", # input columns for the second table
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"ProjectNode.*", # output columns
"HashJoinNode.*", # the join
"ProjectNode.*", # input columns for the second table
"\"doubled_dbl\"\\: multiply_checked\\(dbl, 2\\).*", # mutate
"TableSourceNode.*", # second table
"ProjectNode.*", # input columns for the first table
"TableSourceNode" # first table
"TableSourceNode.*", # second table
"TableSourceNode" # first table
)
)

Expand All @@ -539,11 +538,10 @@ test_that("show_exec_plan(), show_query() and explain()", {
arrange(desc(wt)) %>%
show_exec_plan(),
regexp = paste0(
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"OrderBySinkNode.*wt.*DESC.*", # arrange goes via the OrderBy sink node
"ProjectNode.*", # output columns
"FilterNode.*", # filter node
"TableSourceNode.*" # entry point
"FilterNode.*", # filter node
"TableSourceNode.*" # entry point
)
)

Expand All @@ -559,3 +557,27 @@ test_that("show_exec_plan(), show_query() and explain()", {
"The `ExecPlan` cannot be printed for a nested query."
)
})

test_that("needs_projection unit tests", {
  tab <- Table$create(tbl)
  tab_schema <- tab$schema

  # Ask needs_projection() about a query's selected columns, measured against
  # the schema of the source table
  projects <- function(query) {
    needs_projection(query$selected_columns, tab_schema)
  }

  # Queries that leave the columns untouched need no projection
  expect_false(projects(as_adq(tab)))
  expect_false(projects(collapse(collapse(tab))))

  # Deriving, dropping, renaming, or moving a column requires a projection
  expect_true(projects(mutate(tab, int = int + 2)))
  expect_true(projects(select(tab, int, chr)))
  expect_true(projects(rename(tab, int2 = int)))
  expect_true(projects(relocate(tab, lgl)))
})
41 changes: 39 additions & 2 deletions r/tests/testthat/test-dplyr-summarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,10 @@ test_that("n_distinct() with many batches", {
write_parquet(dplyr::starwars, tf, chunk_size = 20)

ds <- open_dataset(tf)
expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(),
ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE)))
expect_equal(
ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(),
ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE))
)
})

test_that("n_distinct() on dataset", {
Expand Down Expand Up @@ -1089,3 +1091,38 @@ test_that("summarise() supports namespacing", {
tbl
)
})

test_that("We don't add unnecessary ProjectNodes when aggregating", {
  tab <- Table$create(tbl)

  # Expect the printed plan for `query` to contain `n` lines that mention a
  # ProjectNode
  expect_n_projections <- function(query, n) {
    plan_lines <- capture.output(show_query(query))
    expect_length(grep("ProjectNode", plan_lines), n)
  }

  # 1 projection: select int as `mean(int)` before aggregation
  expect_n_projections(summarize(tab, mean(int)), 1)

  # 0 projections only if
  # (a) input only contains the col you're aggregating, and
  # (b) the output col name is the same as the input name, and
  # (c) no grouping
  expect_n_projections(
    summarize(tab[, "int"], int = mean(int, na.rm = TRUE)),
    0
  )

  # 2 projections: one before the aggregation, and one after it in order to
  # put the grouping cols first
  expect_n_projections(tab %>% group_by(lgl) %>% summarize(mean(int)), 2)
  expect_n_projections(count(tab, lgl), 2)
})

0 comments on commit 80bba29

Please sign in to comment.