diff --git a/r/DESCRIPTION b/r/DESCRIPTION
index f37e6a4e84f9d..3de40f6f9a7e9 100644
--- a/r/DESCRIPTION
+++ b/r/DESCRIPTION
@@ -30,6 +30,7 @@ Imports:
     purrr,
     R6,
     rlang,
+    stats,
     tidyselect,
     utils,
     vctrs
diff --git a/r/NAMESPACE b/r/NAMESPACE
index 49d845805f402..fb3ea82c4af42 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -58,6 +58,7 @@ S3method(match_arrow,ArrowDatum)
 S3method(match_arrow,default)
 S3method(max,ArrowDatum)
 S3method(mean,ArrowDatum)
+S3method(median,ArrowDatum)
 S3method(min,ArrowDatum)
 S3method(names,Dataset)
 S3method(names,FeatherReader)
@@ -73,6 +74,7 @@ S3method(print,array_expression)
 S3method(print,arrow_dplyr_query)
 S3method(print,arrow_info)
 S3method(print,arrow_r_metadata)
+S3method(quantile,ArrowDatum)
 S3method(read_message,InputStream)
 S3method(read_message,MessageReader)
 S3method(read_message,default)
@@ -152,6 +154,7 @@ export(ParquetFileWriter)
 export(ParquetVersionType)
 export(ParquetWriterProperties)
 export(Partitioning)
+export(QuantileInterpolation)
 export(RandomAccessFile)
 export(ReadableFile)
 export(RecordBatchFileReader)
@@ -308,6 +311,8 @@ importFrom(rlang,seq2)
 importFrom(rlang,set_names)
 importFrom(rlang,syms)
 importFrom(rlang,warn)
+importFrom(stats,median)
+importFrom(stats,quantile)
 importFrom(tidyselect,contains)
 importFrom(tidyselect,ends_with)
 importFrom(tidyselect,eval_select)
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index 2515e7ac9208e..10dae65fd3f7c 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+#' @importFrom stats quantile median
 #' @importFrom R6 R6Class
 #' @importFrom purrr as_mapper map map2 map_chr map_dfr map_int map_lgl keep
 #' @importFrom assertthat assert_that is.string
diff --git a/r/R/compute.R b/r/R/compute.R
index 09f2c653a8aec..749da6d52773b 100644
--- a/r/R/compute.R
+++ b/r/R/compute.R
@@ -80,6 +80,59 @@ collect_arrays_from_dots <- function(dots) {
   ChunkedArray$create(!!!arrays)
 }
 
+# Quantiles of an ArrowDatum via the Arrow "quantile" compute function.
+# Returns the kernel's result, or an all-NA Array of length(probs) when the
+# kernel returns an empty result. Only `type = 7` is accepted.
+#' @export
+quantile.ArrowDatum <- function(x,
+                                probs = seq(0, 1, 0.25),
+                                na.rm = FALSE,
+                                type = 7,
+                                interpolation = c("linear", "lower", "higher", "nearest", "midpoint"),
+                                ...) {
+  if (inherits(x, "Scalar")) x <- Array$create(x)
+  assert_is(probs, c("numeric", "integer"))
+  assert_that(length(probs) > 0)
+  assert_that(all(probs >= 0 & probs <= 1))
+  if (!na.rm && x$null_count > 0) {
+    stop("Missing values not allowed if 'na.rm' is FALSE", call. = FALSE)
+  }
+  if (type != 7) {
+    stop(
+      "Argument `type` not supported in Arrow. To control the quantile ",
+      "interpolation algorithm, set argument `interpolation` to one of: ",
+      "\"linear\" (the default), \"lower\", \"higher\", \"nearest\", or ",
+      "\"midpoint\".",
+      call. = FALSE
+    )
+  }
+  interpolation <- QuantileInterpolation[[toupper(match.arg(interpolation))]]
+  # The C++ option parser only reads `q` when it is a double vector; integer
+  # `probs` (e.g. c(0L, 1L)) would be skipped there and the kernel would run
+  # with its default `q` instead, so coerce to double before the call.
+  out <- call_function(
+    "quantile", x,
+    options = list(q = as.numeric(probs), interpolation = interpolation)
+  )
+  if (length(out) == 0) {
+    # When there are no non-missing values in the data, the Arrow quantile
+    # function returns an empty Array, but for consistency with the R quantile
+    # function, we want an Array of NA_real_ with the same length as probs
+    out <- Array$create(rep(NA_real_, length(probs)))
+  }
+  out
+}
+
+# Median as an Arrow Scalar: NA when nulls are present and `na.rm` is FALSE,
+# otherwise the 0.5 quantile computed with nulls removed.
+#' @export
+median.ArrowDatum <- function(x, na.rm = FALSE, ...) {
+  if (!na.rm && x$null_count > 0) {
+    Scalar$create(NA_real_)
+  } else {
+    Scalar$create(quantile(x, probs = 0.5, na.rm = TRUE, ...))
+  }
+}
+
 #' @export
 unique.ArrowDatum <- function(x, incomparables = FALSE, ...) {
   call_function("unique", x)
diff --git a/r/R/enums.R b/r/R/enums.R
index 14910bc92e03a..170abf998657e 100644
--- a/r/R/enums.R
+++ b/r/R/enums.R
@@ -128,3 +128,9 @@ ParquetVersionType <- enum("ParquetVersionType",
 MetadataVersion <- enum("MetadataVersion",
   V1 = 0L, V2 = 1L, V3 = 2L, V4 = 3L, V5 = 4L
 )
+
+#' @export
+#' @rdname enums
+QuantileInterpolation <- enum("QuantileInterpolation",
+  LINEAR = 0L, LOWER = 1L, HIGHER = 2L, NEAREST = 3L, MIDPOINT = 4L
+)
diff --git a/r/man/enums.Rd b/r/man/enums.Rd
index e4cb2d854697b..fa3c64b8f955e 100644
--- a/r/man/enums.Rd
+++ b/r/man/enums.Rd
@@ -13,6 +13,7 @@
 \alias{FileType}
 \alias{ParquetVersionType}
 \alias{MetadataVersion}
+\alias{QuantileInterpolation}
 \title{Arrow enums}
 \format{
 An object of class \code{TimeUnit::type} (inherits from \code{arrow-enum}) of length 4.
@@ -34,6 +35,8 @@ An object of class \code{FileType} (inherits from \code{arrow-enum}) of length 4
 An object of class \code{ParquetVersionType} (inherits from \code{arrow-enum}) of length 2.
 
 An object of class \code{MetadataVersion} (inherits from \code{arrow-enum}) of length 5.
+
+An object of class \code{QuantileInterpolation} (inherits from \code{arrow-enum}) of length 5.
 }
 \usage{
 TimeUnit
@@ -55,6 +58,8 @@ FileType
 ParquetVersionType
 
 MetadataVersion
+
+QuantileInterpolation
 }
 \description{
 Arrow enums
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 61ccf07d07bbc..9600eb0d62102 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -179,6 +179,23 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
     return out;
   }
 
+  if (func_name == "quantile") {
+    using Options = arrow::compute::QuantileOptions;
+    auto out = std::make_shared<Options>(Options::Defaults());
+    SEXP q = options["q"];
+    if (!Rf_isNull(q) && TYPEOF(q) == REALSXP) {
+      out->q = cpp11::as_cpp<std::vector<double>>(q);
+    }
+    SEXP interpolation = options["interpolation"];
+    if (!Rf_isNull(interpolation) && TYPEOF(interpolation) == INTSXP &&
+        XLENGTH(interpolation) == 1) {
+      out->interpolation =
+          cpp11::as_cpp<enum arrow::compute::QuantileOptions::Interpolation>(
+              interpolation);
+    }
+    return out;
+  }
+
   if (func_name == "is_in" || func_name == "index_in") {
     using Options = arrow::compute::SetLookupOptions;
     return std::make_shared<Options>(cpp11::as_cpp<arrow::Datum>(options["value_set"]),
diff --git a/r/tests/testthat/test-compute-aggregate.R b/r/tests/testthat/test-compute-aggregate.R
index 5f4eaba49bd2c..32acffb7bebee 100644
--- a/r/tests/testthat/test-compute-aggregate.R
+++ b/r/tests/testthat/test-compute-aggregate.R
@@ -199,6 +199,117 @@ test_that("Edge cases", {
   }
 })
 
+test_that("quantile.Array and quantile.ChunkedArray", {
+  a <- Array$create(c(0, 1, 2, 3))
+  ca <- ChunkedArray$create(c(0, 1), c(2, 3))
+  probs <- c(0.49, 0.51)
+  for (ad in list(a, ca)) {
+    for (type in c(int32(), uint64(), float64())) {
+      expect_equal(
+        quantile(ad$cast(type), probs = probs, interpolation = "linear"),
+        Array$create(c(1.47, 1.53))
+      )
+      expect_equal(
+        quantile(ad$cast(type), probs = probs, interpolation = "lower"),
+        Array$create(c(1, 1))$cast(type)
+      )
+      expect_equal(
+        quantile(ad$cast(type), probs = probs, interpolation = "higher"),
+        Array$create(c(2, 2))$cast(type)
+      )
+      expect_equal(
+        quantile(ad$cast(type), probs = probs, interpolation = "nearest"),
+        Array$create(c(1, 2))$cast(type)
+      )
+      expect_equal(
+        quantile(ad$cast(type), probs = probs, interpolation = "midpoint"),
+        Array$create(c(1.5, 1.5))
+      )
+    }
+  }
+})
+
+test_that("quantile and median NAs, edge cases, and exceptions", {
+  expect_equal(
+    quantile(Array$create(c(1, 2)), probs = c(0, 1)),
+    Array$create(c(1, 2))
+  )
+  expect_equal(
+    quantile(Array$create(c(1, 2)), probs = c(0L, 1L)),
+    Array$create(c(1, 2))
+  )
+  expect_error(
+    quantile(Array$create(c(1, 2, NA))),
+    "Missing values not allowed if 'na.rm' is FALSE"
+  )
+  expect_equal(
+    quantile(Array$create(numeric(0))),
+    Array$create(rep(NA_real_, 5))
+  )
+  expect_equal(
+    quantile(Array$create(rep(NA_integer_, 3)), na.rm = TRUE),
+    Array$create(rep(NA_real_, 5))
+  )
+  expect_error(
+    median(Array$create(c(1, 2)), probs = c(.25, .75))
+  )
+  expect_equal(
+    median(Array$create(c(1, 2)), interpolation = "higher"),
+    Scalar$create(2)
+  )
+  expect_equal(
+    quantile(Scalar$create(0L)),
+    Array$create(rep(0, 5))
+  )
+  expect_equal(
+    median(Scalar$create(1L)),
+    Scalar$create(1)
+  )
+  expect_error(
+    quantile(Array$create(1:3), type = 9),
+    "not supported"
+  )
+})
+
+test_that("median.Array and median.ChunkedArray", {
+  expect_vector_equal(
+    median(input),
+    1:4
+  )
+  expect_vector_equal(
+    median(input),
+    1:5
+  )
+  expect_vector_equal(
+    median(input),
+    numeric(0)
+  )
+  expect_vector_equal(
+    median(input, na.rm = FALSE),
+    c(1, 2, NA)
+  )
+  expect_vector_equal(
+    median(input, na.rm = TRUE),
+    c(1, 2, NA)
+  )
+  expect_vector_equal(
+    median(input, na.rm = TRUE),
+    NA_real_
+  )
+  expect_vector_equal(
+    median(input, na.rm = FALSE),
+    c(1, 2, NA)
+  )
+  expect_vector_equal(
+    median(input, na.rm = TRUE),
+    c(1, 2, NA)
+  )
+  expect_vector_equal(
+    median(input, na.rm = TRUE),
+    NA_real_
+  )
+})
+
 test_that("unique.Array", {
   a <- Array$create(c(1, 4, 3, 1, 1, 3, 4))
   expect_equal(unique(a), Array$create(c(1, 4, 3)))