Skip to content

Commit

Permalink
[SPARK-24331][SPARKR][SQL] Adding arrays_overlap, array_repeat, map_e…
Browse files Browse the repository at this point in the history
…ntries to SparkR

## What changes were proposed in this pull request?

The PR adds functions `arrays_overlap`, `array_repeat`, `map_entries` to SparkR.

## How was this patch tested?

Tests added into R/pkg/tests/fulltests/test_sparkSQL.R

## Examples
### arrays_overlap
```
df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)),
                           list(list(1L, 2L), list(3L, 4L)),
                           list(list(1L, NA), list(3L, 4L))))
collect(select(df, arrays_overlap(df[[1]], df[[2]])))
```
```
  arrays_overlap(_1, _2)
1                   TRUE
2                  FALSE
3                     NA
```
### array_repeat
```
df <- createDataFrame(list(list("a", 3L), list("b", 2L)))
collect(select(df, array_repeat(df[[1]], df[[2]])))
```
```
  array_repeat(_1, _2)
1              a, a, a
2                 b, b
```
```
collect(select(df, array_repeat(df[[1]], 2L)))
```
```
  array_repeat(_1, 2)
1                a, a
2                b, b
```
### map_entries
```
df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
collect(select(df, map_entries(df$map)))
```
```
  map_entries(map)
1       x, 1, y, 2
```

Author: Marek Novotny <[email protected]>

Closes apache#21434 from mn-mikke/SPARK-24331.
  • Loading branch information
mn-mikke authored and Felix Cheung committed May 30, 2018
1 parent f489388 commit a4be981
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 6 deletions.
3 changes: 3 additions & 0 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ exportMethods("%<=>%",
"array_max",
"array_min",
"array_position",
"array_repeat",
"array_sort",
"arrays_overlap",
"asc",
"ascii",
"asin",
Expand Down Expand Up @@ -302,6 +304,7 @@ exportMethods("%<=>%",
"lower",
"lpad",
"ltrim",
"map_entries",
"map_keys",
"map_values",
"max",
Expand Down
2 changes: 2 additions & 0 deletions R/pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -2297,6 +2297,8 @@ setMethod("rename",

setClassUnion("characterOrColumn", c("character", "Column"))

setClassUnion("numericOrColumn", c("numeric", "Column"))

#' Arrange Rows by Variables
#'
#' Sort a SparkDataFrame by the specified column(s).
Expand Down
58 changes: 53 additions & 5 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ NULL
#' the map or array of maps.
#' \item \code{from_json}: it is the column containing the JSON string.
#' }
#' @param y Column to compute on.
#' @param value A value to compute on.
#' \itemize{
#' \item \code{array_contains}: a value to be checked if contained in the column.
Expand All @@ -207,7 +208,7 @@ NULL
#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1)))
#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
#' head(tmp2)
Expand All @@ -216,11 +217,10 @@ NULL
#' head(select(tmp, sort_array(tmp$v1)))
#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
#' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
#' head(select(tmp3, map_keys(tmp3$v3)))
#' head(select(tmp3, map_values(tmp3$v3)))
#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3)))
#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp))
#' head(select(tmp4, concat(tmp4$v4, tmp4$v5)))
#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))}
NULL

Expand Down Expand Up @@ -3048,6 +3048,26 @@ setMethod("array_position",
column(jc)
})

#' @details
#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times
#' given by \code{count}.
#'
#' @param count a Column or constant determining the number of repetitions.
#' @rdname column_collection_functions
#' @aliases array_repeat array_repeat,Column,numericOrColumn-method
#' @note array_repeat since 2.4.0
setMethod("array_repeat",
signature(x = "Column", count = "numericOrColumn"),
function(x, count) {
if (class(count) == "Column") {
count <- count@jc
} else {
count <- as.integer(count)
}
jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count)
column(jc)
})

#' @details
#' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array
#' must be orderable. NA elements will be placed at the end of the returned array.
Expand All @@ -3062,6 +3082,21 @@ setMethod("array_sort",
column(jc)
})

#' @details
#' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in
#' common. If not and both arrays are non-empty and any of them contains a null, it returns null.
#' It returns false otherwise.
#'
#' @rdname column_collection_functions
#' @aliases arrays_overlap arrays_overlap,Column-method
#' @note arrays_overlap since 2.4.0
setMethod("arrays_overlap",
signature(x = "Column", y = "Column"),
function(x, y) {
jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc)
column(jc)
})

#' @details
#' \code{flatten}: Creates a single array from an array of arrays.
#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed.
Expand All @@ -3076,6 +3111,19 @@ setMethod("flatten",
column(jc)
})

#' @details
#' \code{map_entries}: Returns an unordered array of all entries in the given map.
#'
#' @rdname column_collection_functions
#' @aliases map_entries map_entries,Column-method
#' @note map_entries since 2.4.0
setMethod("map_entries",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc)
column(jc)
})

#' @details
#' \code{map_keys}: Returns an unordered array containing the keys of the map.
#'
Expand Down
12 changes: 12 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -769,10 +769,18 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") })
#' @name NULL
setGeneric("array_position", function(x, value) { standardGeneric("array_position") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_sort", function(x) { standardGeneric("array_sort") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("ascii", function(x) { standardGeneric("ascii") })
Expand Down Expand Up @@ -1034,6 +1042,10 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") })
#' @name NULL
setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_entries", function(x) { standardGeneric("map_entries") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("map_keys", function(x) { standardGeneric("map_keys") })
Expand Down
22 changes: 21 additions & 1 deletion R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,21 @@ test_that("column functions", {
result <- collect(select(df2, reverse(df2[[1]])))[[1]]
expect_equal(result, "cba")

# Test array_repeat()
df <- createDataFrame(list(list("a", 3L), list("b", 2L)))
result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]]
expect_equal(result, list(list("a", "a", "a"), list("b", "b")))

result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]]
expect_equal(result, list(list("a", "a"), list("b", "b")))

# Test arrays_overlap()
df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)),
list(list(1L, 2L), list(3L, 4L)),
list(list(1L, NA), list(3L, 4L))))
result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]]
expect_equal(result, c(TRUE, FALSE, NA))

# Test array_sort() and sort_array()
df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L))))

Expand Down Expand Up @@ -1531,8 +1546,13 @@ test_that("column functions", {
result <- collect(select(df, flatten(df[[1]])))[[1]]
expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L)))

# Test map_keys(), map_values() and element_at()
# Test map_entries(), map_keys(), map_values() and element_at()
df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
result <- collect(select(df, map_entries(df$map)))[[1]]
expected_entries <- list(listToStruct(list(key = "x", value = 1)),
listToStruct(list(key = "y", value = 2)))
expect_equal(result, list(expected_entries))

result <- collect(select(df, map_keys(df$map)))[[1]]
expect_equal(result, list(list("x", "y")))

Expand Down

0 comments on commit a4be981

Please sign in to comment.