diff --git a/r/R/arrow-tabular.R b/r/R/arrow-tabular.R index 4682e89327f1b..43110ccf24eab 100644 --- a/r/R/arrow-tabular.R +++ b/r/R/arrow-tabular.R @@ -22,7 +22,12 @@ ArrowTabular <- R6Class("ArrowTabular", inherit = ArrowObject, public = list( - ToString = function() ToString_tabular(self), + ToString = function() { + sch <- unlist(strsplit(self$schema$ToString(), "\n")) + sch <- sub("(.*): (.*)", "$\\1 <\\2>", sch) + dims <- sprintf("%s rows x %s columns", self$num_rows, self$num_columns) + paste(c(dims, sch), collapse = "\n") + }, Take = function(i) { if (is.numeric(i)) { i <- as.integer(i) @@ -57,6 +62,39 @@ ArrowTabular <- R6Class("ArrowTabular", options = list(names = names, orders = as.integer(descending)) ) } + ), + active = list( + metadata = function(new) { + if (missing(new)) { + # Get the metadata (from the schema) + self$schema$metadata + } else { + # Set the metadata + new <- prepare_key_value_metadata(new) + out <- self$ReplaceSchemaMetadata(new) + # ReplaceSchemaMetadata returns a new object but we're modifying in place, + # so swap in that new C++ object pointer into our R6 object + self$set_pointer(out$pointer()) + self + } + }, + r_metadata = function(new) { + # Helper for the R metadata that handles the serialization + # See also method on Schema + if (missing(new)) { + out <- self$metadata$r + if (!is.null(out)) { + # Can't unserialize NULL + out <- .unserialize_arrow_r_metadata(out) + } + # Returns either NULL or a named list + out + } else { + # Set the R metadata + self$metadata$r <- .serialize_arrow_r_metadata(new) + self + } + } ) ) @@ -195,6 +233,9 @@ as.data.frame.ArrowTabular <- function(x, row.names = NULL, optional = FALSE, .. #' @export dim.ArrowTabular <- function(x) c(x$num_rows, x$num_columns) +#' @export +length.ArrowTabular <- function(x) x$num_columns + #' @export as.list.ArrowTabular <- function(x, ...) as.list(as.data.frame(x, ...)) @@ -229,14 +270,3 @@ na.omit.ArrowTabular <- function(object, ...) { #' @export na.exclude.ArrowTabular <- na.omit.ArrowTabular - -ToString_tabular <- function(x, ...) { - # Generic to work with both RecordBatch and Table - sch <- unlist(strsplit(x$schema$ToString(), "\n")) - sch <- sub("(.*): (.*)", "$\\1 <\\2>", sch) - dims <- sprintf("%s rows x %s columns", nrow(x), ncol(x)) - paste(c(dims, sch), collapse = "\n") -} - -#' @export -length.ArrowTabular <- function(x) x$num_columns diff --git a/r/R/dplyr-group-by.R b/r/R/dplyr-group-by.R index 83cc3abe3514e..735662e7461fd 100644 --- a/r/R/dplyr-group-by.R +++ b/r/R/dplyr-group-by.R @@ -61,7 +61,10 @@ groups.arrow_dplyr_query <- function(x) syms(dplyr::group_vars(x)) groups.Dataset <- groups.ArrowTabular <- function(x) NULL group_vars.arrow_dplyr_query <- function(x) x$group_by_vars -group_vars.Dataset <- group_vars.ArrowTabular <- function(x) NULL +group_vars.Dataset <- function(x) NULL +group_vars.ArrowTabular <- function(x) { + x$r_metadata$attributes$.group_vars +} # the logical literal in the two functions below controls the default value of # the .drop argument to group_by() @@ -75,4 +78,8 @@ ungroup.arrow_dplyr_query <- function(x, ...) { x$drop_empty_groups <- NULL x } -ungroup.Dataset <- ungroup.ArrowTabular <- force +ungroup.Dataset <- force +ungroup.ArrowTabular <- function(x) { + x$r_metadata$attributes$.group_vars <- NULL + x +} diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 954c0c5b31950..810785e0a7c6b 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -24,6 +24,7 @@ arrow_dplyr_query <- function(.data) { # RecordBatch, or Dataset) and the state of the user's dplyr query--things # like selected columns, filters, and group vars. # An arrow_dplyr_query can contain another arrow_dplyr_query in .data + gv <- dplyr::group_vars(.data) %||% character() if (!inherits(.data, c("Dataset", "arrow_dplyr_query"))) { .data <- InMemoryDataset$create(.data) } @@ -50,7 +51,7 @@ arrow_dplyr_query <- function(.data) { filtered_rows = TRUE, # group_by_vars is a character vector of columns (as renamed) # in the data. They will be kept when data is pulled into R. - group_by_vars = character(), + group_by_vars = gv, # drop_empty_groups is a logical value indicating whether to drop # groups formed by factor levels that don't appear in the data. It # should be non-null only when the data is grouped. diff --git a/r/R/metadata.R b/r/R/metadata.R index fc1f0229dc85a..768abeda72fc6 100644 --- a/r/R/metadata.R +++ b/r/R/metadata.R @@ -99,6 +99,10 @@ apply_arrow_r_metadata <- function(x, r_metadata) { # of class data.frame, remove the extraneous attribute attr(x, "row.names") <- NULL } + if (!is.null(attr(x, ".group_vars")) && requireNamespace("dplyr", quietly = TRUE)) { + x <- dplyr::group_by(x, !!!syms(attr(x, ".group_vars"))) + attr(x, ".group_vars") <- NULL + } } }, error = function(e) { @@ -129,6 +133,19 @@ remove_attributes <- function(x) { } arrow_attributes <- function(x, only_top_level = FALSE) { + if (inherits(x, "grouped_df")) { + # Keep only the group var names, not the rest of the cached data that dplyr + # uses, which may be large + if (requireNamespace("dplyr", quietly = TRUE)) { + gv <- dplyr::group_vars(x) + x <- dplyr::ungroup(x) + # ungroup() first, then set attribute, bc ungroup() would erase it + attr(x, ".group_vars") <- gv + } else { + # Regardless, we shouldn't keep groups around + attr(x, "groups") <- NULL + } + } att <- attributes(x) removed_attributes <- remove_attributes(x) diff --git a/r/R/query-engine.R b/r/R/query-engine.R index aa9740196bbf9..335e1b72d7b36 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -37,9 +37,8 @@ do_exec_plan <- function(.data) { original_schema <- source_data(.data)$schema # TODO: do we care about other (non-R) metadata preservation? # How would we know if it were meaningful? - r_meta <- original_schema$metadata$r + r_meta <- original_schema$r_metadata if (!is.null(r_meta)) { - r_meta <- .unserialize_arrow_r_metadata(r_meta) # Filter r_metadata$columns on columns with name _and_ type match new_schema <- tab$schema common_names <- intersect(names(r_meta$columns), names(tab)) @@ -51,7 +50,7 @@ do_exec_plan <- function(.data) { # dplyr drops top-level attributes if you do summarize r_meta$attributes <- NULL } - tab$metadata$r <- .serialize_arrow_r_metadata(r_meta) + tab$r_metadata <- r_meta } } diff --git a/r/R/record-batch.R b/r/R/record-batch.R index e1c5251b25423..c66ff7fb0f7cd 100644 --- a/r/R/record-batch.R +++ b/r/R/record-batch.R @@ -99,6 +99,9 @@ RecordBatch <- R6Class("RecordBatch", RecordBatch__SetColumn(self, i, new_field, value) }, RemoveColumn = function(i) RecordBatch__RemoveColumn(self, i), + ReplaceSchemaMetadata = function(new) { + RecordBatch__ReplaceSchemaMetadata(self, new) + }, Slice = function(offset, length = NULL) { if (is.null(length)) { RecordBatch__Slice1(self, offset) @@ -128,20 +131,6 @@ RecordBatch <- R6Class("RecordBatch", num_columns = function() RecordBatch__num_columns(self), num_rows = function() RecordBatch__num_rows(self), schema = function() RecordBatch__schema(self), - metadata = function(new) { - if (missing(new)) { - # Get the metadata (from the schema) - self$schema$metadata - } else { - # Set the metadata - new <- prepare_key_value_metadata(new) - out <- RecordBatch__ReplaceSchemaMetadata(self, new) - # ReplaceSchemaMetadata returns a new object but we're modifying in place, - # so swap in that new C++ object pointer into our R6 object - self$set_pointer(out$pointer()) - self - } - }, columns = function() RecordBatch__columns(self) ) ) @@ -159,12 +148,6 @@ RecordBatch$create <- function(..., schema = NULL) { } stopifnot(length(arrays) > 0) - # Preserve any grouping - if (length(arrays) == 1 && inherits(arrays[[1]], "grouped_df")) { - out <- RecordBatch__from_arrays(schema, arrays) - return(dplyr::group_by(out, !!!dplyr::groups(arrays[[1]]))) - } - # If any arrays are length 1, recycle them arrays <- recycle_scalars(arrays) diff --git a/r/R/schema.R b/r/R/schema.R index e4ea618fd6d7d..f46bd09d91990 100644 --- a/r/R/schema.R +++ b/r/R/schema.R @@ -133,6 +133,23 @@ Schema <- R6Class("Schema", self$set_pointer(out$pointer()) self } + }, + r_metadata = function(new) { + # Helper for the R metadata that handles the serialization + # See also method on ArrowTabular + if (missing(new)) { + out <- self$metadata$r + if (!is.null(out)) { + # Can't unserialize NULL + out <- .unserialize_arrow_r_metadata(out) + } + # Returns either NULL or a named list + out + } else { + # Set the R metadata + self$metadata$r <- .serialize_arrow_r_metadata(new) + self + } } ) ) diff --git a/r/R/table.R b/r/R/table.R index 5aae067f0fc10..fcd599e07acd0 100644 --- a/r/R/table.R +++ b/r/R/table.R @@ -108,6 +108,9 @@ Table <- R6Class("Table", RemoveColumn = function(i) Table__RemoveColumn(self, i), AddColumn = function(i, new_field, value) Table__AddColumn(self, i, new_field, value), SetColumn = function(i, new_field, value) Table__SetColumn(self, i, new_field, value), + ReplaceSchemaMetadata = function(new) { + Table__ReplaceSchemaMetadata(self, new) + }, field = function(i) Table__field(self, i), serialize = function(output_stream, ...) write_table(self, output_stream, ...), to_data_frame = function() { @@ -141,20 +144,6 @@ Table <- R6Class("Table", num_columns = function() Table__num_columns(self), num_rows = function() Table__num_rows(self), schema = function() Table__schema(self), - metadata = function(new) { - if (missing(new)) { - # Get the metadata (from the schema) - self$schema$metadata - } else { - # Set the metadata - new <- prepare_key_value_metadata(new) - out <- Table__ReplaceSchemaMetadata(self, new) - # ReplaceSchemaMetadata returns a new object but we're modifying in place, - # so swap in that new C++ object pointer into our R6 object - self$set_pointer(out$pointer()) - self - } - }, columns = function() Table__columns(self) ) ) @@ -174,13 +163,7 @@ Table$create <- function(..., schema = NULL) { # If any arrays are length 1, recycle them dots <- recycle_scalars(dots) - out <- Table__from_dots(dots, schema, option_use_threads()) - - # Preserve any grouping - if (length(dots) == 1 && inherits(dots[[1]], "grouped_df")) { - out <- dplyr::group_by(out, !!!dplyr::groups(dots[[1]])) - } - out + Table__from_dots(dots, schema, option_use_threads()) } #' @export diff --git a/r/tests/testthat/helper-parquet.R b/r/tests/testthat/helper-parquet.R index 7697d24d39ddb..a0dd445bb8939 100644 --- a/r/tests/testthat/helper-parquet.R +++ b/r/tests/testthat/helper-parquet.R @@ -19,11 +19,11 @@ expect_parquet_roundtrip <- function(tab, ...) { expect_equal(parquet_roundtrip(tab, ...), tab) } -parquet_roundtrip <- function(x, ...) { +parquet_roundtrip <- function(x, ..., as_data_frame = FALSE) { # write/read parquet, returns Table tf <- tempfile() on.exit(unlink(tf)) write_parquet(x, tf, ...) - read_parquet(tf, as_data_frame = FALSE) + read_parquet(tf, as_data_frame = as_data_frame) } diff --git a/r/tests/testthat/test-RecordBatch.R b/r/tests/testthat/test-RecordBatch.R index f80ccd979ab3c..dfe4067f73cb6 100644 --- a/r/tests/testthat/test-RecordBatch.R +++ b/r/tests/testthat/test-RecordBatch.R @@ -553,8 +553,9 @@ test_that("Handling string data with embedded nuls", { }) }) -test_that("ARROW-11769 - grouping preserved in record batch creation", { +test_that("ARROW-11769/ARROW-13860 - grouping preserved in record batch creation", { skip_if_not_available("dataset") + library(dplyr, warn.conflicts = FALSE) tbl <- tibble::tibble( int = 1:10, @@ -562,11 +563,33 @@ test_that("ARROW-11769 - grouping preserved in record batch creation", { fct2 = factor(rep(c("C", "D"), each = 5)), ) + expect_r6_class( + tbl %>% + group_by(fct, fct2) %>% + record_batch(), + "RecordBatch" + ) + expect_identical( + tbl %>% + group_by(fct, fct2) %>% + record_batch() %>% + group_vars(), + c("fct", "fct2") + ) + expect_identical( + tbl %>% + group_by(fct, fct2) %>% + record_batch() %>% + ungroup() %>% + group_vars(), + NULL + ) expect_identical( tbl %>% - dplyr::group_by(fct, fct2) %>% + group_by(fct, fct2) %>% record_batch() %>% - dplyr::group_vars(), + select(-int) %>% + group_vars(), c("fct", "fct2") ) }) diff --git a/r/tests/testthat/test-duckdb.R b/r/tests/testthat/test-duckdb.R index 66ba68ee85674..a30c541997792 100644 --- a/r/tests/testthat/test-duckdb.R +++ b/r/tests/testthat/test-duckdb.R @@ -20,8 +20,8 @@ skip_if_not_installed("dbplyr") skip_if_not_available("dataset") skip_on_cran() -library(duckdb) -library(dplyr) +library(duckdb, quietly = TRUE) +library(dplyr, warn.conflicts = FALSE) test_that("to_duckdb", { ds <- InMemoryDataset$create(example_data) diff --git a/r/tests/testthat/test-metadata.R b/r/tests/testthat/test-metadata.R index f70dcd64f7234..c560da77081df 100644 --- a/r/tests/testthat/test-metadata.R +++ b/r/tests/testthat/test-metadata.R @@ -219,7 +219,8 @@ test_that("Row-level metadata (does not by default) roundtrip", { # metadata should be handled separately ARROW-14020, ARROW-12542 df <- data.frame(x = I(list(structure(1, foo = "bar"), structure(2, baz = "qux")))) tab <- Table$create(df) - r_metadata <- .unserialize_arrow_r_metadata(tab$metadata$r) + r_metadata <- tab$r_metadata + expect_type(r_metadata, "list") expect_null(r_metadata$columns$x$columns) # But we can re-enable this / read data that has already been written with @@ -239,8 +240,7 @@ test_that("Row-level metadata (does not) roundtrip in datasets", { skip_if_not_available("dataset") skip_if_not_available("parquet") - local_edition(3) - library(dplyr) + library(dplyr, warn.conflicts = FALSE) df <- tibble::tibble( metadata = list( @@ -359,3 +359,11 @@ test_that("dplyr with metadata", { example_with_metadata ) }) + +test_that("grouped_df metadata is recorded (efficiently)", { + grouped <- group_by(tibble(a = 1:2, b = 3:4), a) + expect_s3_class(grouped, "grouped_df") + grouped_tab <- Table$create(grouped) + expect_r6_class(grouped_tab, "Table") + expect_equal(grouped_tab$r_metadata$attributes$.group_vars, "a") +}) diff --git a/r/tests/testthat/test-parquet.R b/r/tests/testthat/test-parquet.R index 791c7b61cca11..55d86b532b02b 100644 --- a/r/tests/testthat/test-parquet.R +++ b/r/tests/testthat/test-parquet.R @@ -104,6 +104,13 @@ test_that("write_parquet() accepts RecordBatch too", { expect_equal(tab, Table$create(batch)) }) +test_that("write_parquet() handles grouped_df", { + library(dplyr, warn.conflicts = FALSE) + df <- tibble::tibble(a = 1:4, b = 5) %>% group_by(b) + # Since `df` is a "grouped_df", this test asserts that we get a grouped_df back + expect_parquet_roundtrip(df, as_data_frame = TRUE) +}) + test_that("write_parquet() with invalid input type", { bad_input <- Array$create(1:5) expect_error(