ARROW-13860: [R] arrow 5.0.0 write_parquet throws error writing grouped data.frame

* `Table/RecordBatch$create()` on a `grouped_df` no longer returns an `arrow_dplyr_query`, reverting a change made in the last release. These functions are type stable again, which fixes the user report that `write_parquet()` fails on a grouped data.frame (see the sketch after this list).
* Instead of creating an `arrow_dplyr_query`, group vars are stored in a special `.group_vars` attribute in `metadata$r`. This attribute is used to restore groups on the round trip back to R, so `grouped_df %>% record_batch() %>% as.data.frame()` returns a `grouped_df`.
* The current dplyr release caches a lot of metadata about groups in a `grouped_df`, including the row indices matching each group value. Because that cache bloated the schema metadata we serialize, it is dropped here; dplyr recreates it when the data is converted back to a `grouped_df`/`data.frame`.
* The `group_vars()` and `ungroup()` methods for `ArrowTabular` read and write this new `metadata$r$attributes$.group_vars` field, so `df %>% group_by() %>% record_batch() %>% group_vars()` returns the same as `df %>% record_batch() %>% group_by() %>% group_vars()`. `arrow_dplyr_query()` also picks up the stored group vars.
* A new active binding, `$r_metadata`, wraps the (de)serialization of R metadata into the Arrow string KeyValueMetadata.
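A minimal sketch of the fixed behavior (illustrative data; assumes this patch plus dplyr is installed):

```r
library(arrow)
library(dplyr, warn.conflicts = FALSE)

df <- group_by(tibble::tibble(x = 1:4, g = c("a", "a", "b", "b")), g)

class(Table$create(df))[1]     # "Table" -- type stable again, not an arrow_dplyr_query

tf <- tempfile(fileext = ".parquet")
write_parquet(df, tf)          # errored in arrow 5.0.0 when given a grouped_df
group_vars(read_parquet(tf))   # "g" -- groups restored from metadata$r on the way back
```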

Closes apache#11315 from nealrichardson/fix-grouped-df

Authored-by: Neal Richardson <[email protected]>
Signed-off-by: Jonathan Keane <[email protected]>
nealrichardson authored and jonkeane committed Oct 14, 2021
1 parent 5845556 commit 7eba115
Showing 13 changed files with 144 additions and 69 deletions.
54 changes: 42 additions & 12 deletions r/R/arrow-tabular.R
@@ -22,7 +22,12 @@
ArrowTabular <- R6Class("ArrowTabular",
  inherit = ArrowObject,
  public = list(
    ToString = function() ToString_tabular(self),
    ToString = function() {
      sch <- unlist(strsplit(self$schema$ToString(), "\n"))
      sch <- sub("(.*): (.*)", "$\\1 <\\2>", sch)
      dims <- sprintf("%s rows x %s columns", self$num_rows, self$num_columns)
      paste(c(dims, sch), collapse = "\n")
    },
    Take = function(i) {
      if (is.numeric(i)) {
        i <- as.integer(i)
@@ -57,6 +62,39 @@ ArrowTabular <- R6Class("ArrowTabular",
        options = list(names = names, orders = as.integer(descending))
      )
    }
  ),
  active = list(
    metadata = function(new) {
      if (missing(new)) {
        # Get the metadata (from the schema)
        self$schema$metadata
      } else {
        # Set the metadata
        new <- prepare_key_value_metadata(new)
        out <- self$ReplaceSchemaMetadata(new)
        # ReplaceSchemaMetadata returns a new object, but we're modifying in
        # place, so swap the new C++ object pointer into our R6 object
        self$set_pointer(out$pointer())
        self
      }
    },
    r_metadata = function(new) {
      # Helper for the R metadata that handles the (de)serialization
      # See also the method on Schema
      if (missing(new)) {
        out <- self$metadata$r
        if (!is.null(out)) {
          # Can't unserialize NULL
          out <- .unserialize_arrow_r_metadata(out)
        }
        # Returns either NULL or a named list
        out
      } else {
        # Set the R metadata
        self$metadata$r <- .serialize_arrow_r_metadata(new)
        self
      }
    }
  )
)

@@ -195,6 +233,9 @@ as.data.frame.ArrowTabular <- function(x, row.names = NULL, optional = FALSE, ..
#' @export
dim.ArrowTabular <- function(x) c(x$num_rows, x$num_columns)

#' @export
length.ArrowTabular <- function(x) x$num_columns

#' @export
as.list.ArrowTabular <- function(x, ...) as.list(as.data.frame(x, ...))

@@ -229,14 +270,3 @@ na.omit.ArrowTabular <- function(object, ...) {

#' @export
na.exclude.ArrowTabular <- na.omit.ArrowTabular

ToString_tabular <- function(x, ...) {
  # Generic to work with both RecordBatch and Table
  sch <- unlist(strsplit(x$schema$ToString(), "\n"))
  sch <- sub("(.*): (.*)", "$\\1 <\\2>", sch)
  dims <- sprintf("%s rows x %s columns", nrow(x), ncol(x))
  paste(c(dims, sch), collapse = "\n")
}

#' @export
length.ArrowTabular <- function(x) x$num_columns
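To illustrate the new active binding (a sketch, not part of the diff; assigning an arbitrary list here is for demonstration only):

```r
library(arrow)

tab <- Table$create(data.frame(x = 1:3))

# Get: unserializes the "r" key of the schema metadata, or returns NULL
tab$r_metadata

# Set: serializes the list into the Arrow string KeyValueMetadata
tab$r_metadata <- list(attributes = list(note = "example"))
tab$r_metadata$attributes$note   # "example"
```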
11 changes: 9 additions & 2 deletions r/R/dplyr-group-by.R
@@ -61,7 +61,10 @@ groups.arrow_dplyr_query <- function(x) syms(dplyr::group_vars(x))
groups.Dataset <- groups.ArrowTabular <- function(x) NULL

group_vars.arrow_dplyr_query <- function(x) x$group_by_vars
group_vars.Dataset <- group_vars.ArrowTabular <- function(x) NULL
group_vars.Dataset <- function(x) NULL
group_vars.ArrowTabular <- function(x) {
  x$r_metadata$attributes$.group_vars
}

# the logical literal in the two functions below controls the default value of
# the .drop argument to group_by()
@@ -75,4 +78,8 @@ ungroup.arrow_dplyr_query <- function(x, ...) {
  x$drop_empty_groups <- NULL
  x
}
ungroup.Dataset <- ungroup.ArrowTabular <- force
ungroup.Dataset <- force
ungroup.ArrowTabular <- function(x) {
  x$r_metadata$attributes$.group_vars <- NULL
  x
}
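A sketch of the resulting dplyr-facing behavior on an `ArrowTabular` object (mirrors the tests below):

```r
library(arrow)
library(dplyr, warn.conflicts = FALSE)

rb <- record_batch(group_by(tibble::tibble(a = 1:2, b = 3:4), a))
group_vars(rb)      # "a" -- read from metadata$r$attributes$.group_vars
rb <- ungroup(rb)   # clears the stored .group_vars
group_vars(rb)      # NULL
```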
3 changes: 2 additions & 1 deletion r/R/dplyr.R
@@ -24,6 +24,7 @@ arrow_dplyr_query <- function(.data) {
# RecordBatch, or Dataset) and the state of the user's dplyr query--things
# like selected columns, filters, and group vars.
# An arrow_dplyr_query can contain another arrow_dplyr_query in .data
gv <- dplyr::group_vars(.data) %||% character()
if (!inherits(.data, c("Dataset", "arrow_dplyr_query"))) {
.data <- InMemoryDataset$create(.data)
}
@@ -50,7 +51,7 @@
filtered_rows = TRUE,
# group_by_vars is a character vector of columns (as renamed)
# in the data. They will be kept when data is pulled into R.
group_by_vars = character(),
group_by_vars = gv,
# drop_empty_groups is a logical value indicating whether to drop
# groups formed by factor levels that don't appear in the data. It
# should be non-null only when the data is grouped.
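This seeding means a grouped Table entering a dplyr verb keeps its groups in the resulting query (a sketch):

```r
library(arrow)
library(dplyr, warn.conflicts = FALSE)

tab <- Table$create(group_by(tibble::tibble(x = 1:4, g = c("a", "a", "b", "b")), g))
tab %>%
  filter(x > 1) %>%   # wraps the Table in an arrow_dplyr_query
  group_vars()        # "g" -- group_by_vars was seeded from the stored group vars
```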
17 changes: 17 additions & 0 deletions r/R/metadata.R
@@ -99,6 +99,10 @@ apply_arrow_r_metadata <- function(x, r_metadata) {
# of class data.frame, remove the extraneous attribute
attr(x, "row.names") <- NULL
}
if (!is.null(attr(x, ".group_vars")) && requireNamespace("dplyr", quietly = TRUE)) {
x <- dplyr::group_by(x, !!!syms(attr(x, ".group_vars")))
attr(x, ".group_vars") <- NULL
}
}
},
error = function(e) {
@@ -129,6 +133,19 @@ remove_attributes <- function(x) {
}

arrow_attributes <- function(x, only_top_level = FALSE) {
if (inherits(x, "grouped_df")) {
# Keep only the group var names, not the rest of the cached data that dplyr
# uses, which may be large
if (requireNamespace("dplyr", quietly = TRUE)) {
gv <- dplyr::group_vars(x)
x <- dplyr::ungroup(x)
# ungroup() first, then set attribute, bc ungroup() would erase it
attr(x, ".group_vars") <- gv
} else {
# Regardless, we shouldn't keep groups around
attr(x, "groups") <- NULL
}
}
att <- attributes(x)

removed_attributes <- remove_attributes(x)
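A sketch of why only the names are worth keeping: dplyr's cached `groups` attribute grows with the data, while `.group_vars` stays constant-size:

```r
library(arrow)
library(dplyr, warn.conflicts = FALSE)

grouped <- group_by(tibble::tibble(a = rep(1:2, each = 500)), a)
attr(grouped, "groups")                 # tibble with a `.rows` list-column of row indices
tab <- Table$create(grouped)
tab$r_metadata$attributes$.group_vars   # "a" -- only the names are serialized
```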
5 changes: 2 additions & 3 deletions r/R/query-engine.R
@@ -37,9 +37,8 @@ do_exec_plan <- function(.data) {
original_schema <- source_data(.data)$schema
# TODO: do we care about other (non-R) metadata preservation?
# How would we know if it were meaningful?
r_meta <- original_schema$metadata$r
r_meta <- original_schema$r_metadata
if (!is.null(r_meta)) {
r_meta <- .unserialize_arrow_r_metadata(r_meta)
# Filter r_metadata$columns on columns with name _and_ type match
new_schema <- tab$schema
common_names <- intersect(names(r_meta$columns), names(tab))
@@ -51,7 +50,7 @@
# dplyr drops top-level attributes if you do summarize
r_meta$attributes <- NULL
}
tab$metadata$r <- .serialize_arrow_r_metadata(r_meta)
tab$r_metadata <- r_meta
}
}

23 changes: 3 additions & 20 deletions r/R/record-batch.R
@@ -99,6 +99,9 @@ RecordBatch <- R6Class("RecordBatch",
RecordBatch__SetColumn(self, i, new_field, value)
},
RemoveColumn = function(i) RecordBatch__RemoveColumn(self, i),
    ReplaceSchemaMetadata = function(new) {
      RecordBatch__ReplaceSchemaMetadata(self, new)
    },
Slice = function(offset, length = NULL) {
if (is.null(length)) {
RecordBatch__Slice1(self, offset)
@@ -128,20 +131,6 @@
num_columns = function() RecordBatch__num_columns(self),
num_rows = function() RecordBatch__num_rows(self),
schema = function() RecordBatch__schema(self),
metadata = function(new) {
if (missing(new)) {
# Get the metadata (from the schema)
self$schema$metadata
} else {
# Set the metadata
new <- prepare_key_value_metadata(new)
out <- RecordBatch__ReplaceSchemaMetadata(self, new)
# ReplaceSchemaMetadata returns a new object but we're modifying in place,
# so swap in that new C++ object pointer into our R6 object
self$set_pointer(out$pointer())
self
}
},
columns = function() RecordBatch__columns(self)
)
)
@@ -159,12 +148,6 @@ RecordBatch$create <- function(..., schema = NULL) {
}
stopifnot(length(arrays) > 0)

# Preserve any grouping
if (length(arrays) == 1 && inherits(arrays[[1]], "grouped_df")) {
out <- RecordBatch__from_arrays(schema, arrays)
return(dplyr::group_by(out, !!!dplyr::groups(arrays[[1]])))
}

# If any arrays are length 1, recycle them
arrays <- recycle_scalars(arrays)

17 changes: 17 additions & 0 deletions r/R/schema.R
@@ -133,6 +133,23 @@ Schema <- R6Class("Schema",
self$set_pointer(out$pointer())
self
}
    },
    r_metadata = function(new) {
      # Helper for the R metadata that handles the (de)serialization
      # See also the method on ArrowTabular
      if (missing(new)) {
        out <- self$metadata$r
        if (!is.null(out)) {
          # Can't unserialize NULL
          out <- .unserialize_arrow_r_metadata(out)
        }
        # Returns either NULL or a named list
        out
      } else {
        # Set the R metadata
        self$metadata$r <- .serialize_arrow_r_metadata(new)
        self
      }
    }
)
)
25 changes: 4 additions & 21 deletions r/R/table.R
@@ -108,6 +108,9 @@ Table <- R6Class("Table",
RemoveColumn = function(i) Table__RemoveColumn(self, i),
AddColumn = function(i, new_field, value) Table__AddColumn(self, i, new_field, value),
SetColumn = function(i, new_field, value) Table__SetColumn(self, i, new_field, value),
    ReplaceSchemaMetadata = function(new) {
      Table__ReplaceSchemaMetadata(self, new)
    },
field = function(i) Table__field(self, i),
serialize = function(output_stream, ...) write_table(self, output_stream, ...),
to_data_frame = function() {
@@ -141,20 +144,6 @@
num_columns = function() Table__num_columns(self),
num_rows = function() Table__num_rows(self),
schema = function() Table__schema(self),
metadata = function(new) {
if (missing(new)) {
# Get the metadata (from the schema)
self$schema$metadata
} else {
# Set the metadata
new <- prepare_key_value_metadata(new)
out <- Table__ReplaceSchemaMetadata(self, new)
# ReplaceSchemaMetadata returns a new object but we're modifying in place,
# so swap in that new C++ object pointer into our R6 object
self$set_pointer(out$pointer())
self
}
},
columns = function() Table__columns(self)
)
)
@@ -174,13 +163,7 @@ Table$create <- function(..., schema = NULL) {
# If any arrays are length 1, recycle them
dots <- recycle_scalars(dots)

out <- Table__from_dots(dots, schema, option_use_threads())

# Preserve any grouping
if (length(dots) == 1 && inherits(dots[[1]], "grouped_df")) {
out <- dplyr::group_by(out, !!!dplyr::groups(dots[[1]]))
}
out
Table__from_dots(dots, schema, option_use_threads())
}

#' @export
4 changes: 2 additions & 2 deletions r/tests/testthat/helper-parquet.R
@@ -19,11 +19,11 @@ expect_parquet_roundtrip <- function(tab, ...) {
expect_equal(parquet_roundtrip(tab, ...), tab)
}

parquet_roundtrip <- function(x, ...) {
parquet_roundtrip <- function(x, ..., as_data_frame = FALSE) {
# write/read parquet, returns Table
tf <- tempfile()
on.exit(unlink(tf))

write_parquet(x, tf, ...)
read_parquet(tf, as_data_frame = FALSE)
read_parquet(tf, as_data_frame = as_data_frame)
}
29 changes: 26 additions & 3 deletions r/tests/testthat/test-RecordBatch.R
@@ -553,20 +553,43 @@ test_that("Handling string data with embedded nuls", {
})
})

test_that("ARROW-11769 - grouping preserved in record batch creation", {
test_that("ARROW-11769/ARROW-13860 - grouping preserved in record batch creation", {
skip_if_not_available("dataset")
library(dplyr, warn.conflicts = FALSE)

tbl <- tibble::tibble(
int = 1:10,
fct = factor(rep(c("A", "B"), 5)),
fct2 = factor(rep(c("C", "D"), each = 5)),
)

expect_r6_class(
tbl %>%
group_by(fct, fct2) %>%
record_batch(),
"RecordBatch"
)
expect_identical(
tbl %>%
group_by(fct, fct2) %>%
record_batch() %>%
group_vars(),
c("fct", "fct2")
)
expect_identical(
tbl %>%
group_by(fct, fct2) %>%
record_batch() %>%
ungroup() %>%
group_vars(),
NULL
)
expect_identical(
tbl %>%
dplyr::group_by(fct, fct2) %>%
group_by(fct, fct2) %>%
record_batch() %>%
dplyr::group_vars(),
select(-int) %>%
group_vars(),
c("fct", "fct2")
)
})
4 changes: 2 additions & 2 deletions r/tests/testthat/test-duckdb.R
@@ -20,8 +20,8 @@ skip_if_not_installed("dbplyr")
skip_if_not_available("dataset")
skip_on_cran()

library(duckdb)
library(dplyr)
library(duckdb, quietly = TRUE)
library(dplyr, warn.conflicts = FALSE)

test_that("to_duckdb", {
ds <- InMemoryDataset$create(example_data)
14 changes: 11 additions & 3 deletions r/tests/testthat/test-metadata.R
@@ -219,7 +219,8 @@ test_that("Row-level metadata (does not by default) roundtrip", {
# metadata should be handled separately ARROW-14020, ARROW-12542
df <- data.frame(x = I(list(structure(1, foo = "bar"), structure(2, baz = "qux"))))
tab <- Table$create(df)
r_metadata <- .unserialize_arrow_r_metadata(tab$metadata$r)
r_metadata <- tab$r_metadata
expect_type(r_metadata, "list")
expect_null(r_metadata$columns$x$columns)

# But we can re-enable this / read data that has already been written with
Expand All @@ -239,8 +240,7 @@ test_that("Row-level metadata (does not) roundtrip in datasets", {
skip_if_not_available("dataset")
skip_if_not_available("parquet")

local_edition(3)
library(dplyr)
library(dplyr, warn.conflicts = FALSE)

df <- tibble::tibble(
metadata = list(
@@ -359,3 +359,11 @@ test_that("dplyr with metadata", {
example_with_metadata
)
})

test_that("grouped_df metadata is recorded (efficiently)", {
grouped <- group_by(tibble(a = 1:2, b = 3:4), a)
expect_s3_class(grouped, "grouped_df")
grouped_tab <- Table$create(grouped)
expect_r6_class(grouped_tab, "Table")
expect_equal(grouped_tab$r_metadata$attributes$.group_vars, "a")
})
