Skip to content

Commit

Permalink
[SPARK-43011][SQL] array_insert should fail with 0 index
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Make `array_insert` fail when input index `pos` is zero.

### Why are the changes needed?
See the discussion at apache#40563 (comment): an index of 0 is ambiguous because array indexing is 1-based from the front and -1-based from the back, so `array_insert` should reject it rather than silently treat it as 1.

### Does this PR introduce _any_ user-facing change?
Yes

### How was this patch tested?
Updated existing unit tests and SQL golden files to expect the `INVALID_INDEX_OF_ZERO` error for `array_insert(..., 0, ...)`.

Closes apache#40641 from zhengruifeng/sql_array_insert_fails_zero.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
zhengruifeng authored and MaxGekk committed Apr 4, 2023
1 parent 3bc66da commit 3e9574c
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 37 deletions.
12 changes: 6 additions & 6 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -536,12 +536,6 @@
],
"sqlState" : "23505"
},
"ELEMENT_AT_BY_INDEX_ZERO" : {
"message" : [
"The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1)."
],
"sqlState" : "22003"
},
"EMPTY_JSON_FIELD_VALUE" : {
"message" : [
"Failed to parse an empty string for data type <dataType>."
Expand Down Expand Up @@ -915,6 +909,12 @@
],
"sqlState" : "42602"
},
"INVALID_INDEX_OF_ZERO" : {
"message" : [
"The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1)."
],
"sqlState" : "22003"
},
"INVALID_JSON_ROOT_FIELD" : {
"message" : [
"Cannot convert JSON root field to target Spark type."
Expand Down
2 changes: 1 addition & 1 deletion docs/sql-error-conditions-sqlstates.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Spark SQL uses the following `SQLSTATE` classes:
</tr>
<tr>
<td></td>
<td><a href="arithmetic-overflow-error-class.md">ARITHMETIC_OVERFLOW</a>, <a href="sql-error-conditions.html#cast_overflow">CAST_OVERFLOW</a>, <a href="sql-error-conditions.html#cast_overflow_in_table_insert">CAST_OVERFLOW_IN_TABLE_INSERT</a>, <a href="sql-error-conditions.html#decimal_precision_exceeds_max_precision">DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION</a>, <a href="sql-error-conditions.html#element_at_by_index_zero">ELEMENT_AT_BY_INDEX_ZERO</a>, <a href="sql-error-conditions.html#incorrect_end_offset">INCORRECT_END_OFFSET</a>, <a href="sql-error-conditions.html#incorrect_ramp_up_rate">INCORRECT_RAMP_UP_RATE</a>, <a href="invalid-array-index-error-class.md">INVALID_ARRAY_INDEX</a>, <a href="invalid-array-index-in-element-at-error-class.md">INVALID_ARRAY_INDEX_IN_ELEMENT_AT</a>, <a href="sql-error-conditions.html#numeric_out_of_supported_range">NUMERIC_OUT_OF_SUPPORTED_RANGE</a>, <a href="sql-error-conditions.html#numeric_value_out_of_range">NUMERIC_VALUE_OUT_OF_RANGE</a>
<td><a href="arithmetic-overflow-error-class.md">ARITHMETIC_OVERFLOW</a>, <a href="sql-error-conditions.html#cast_overflow">CAST_OVERFLOW</a>, <a href="sql-error-conditions.html#cast_overflow_in_table_insert">CAST_OVERFLOW_IN_TABLE_INSERT</a>, <a href="sql-error-conditions.html#decimal_precision_exceeds_max_precision">DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION</a>, <a href="sql-error-conditions.html#invalid_index_of_zero">INVALID_INDEX_OF_ZERO</a>, <a href="sql-error-conditions.html#incorrect_end_offset">INCORRECT_END_OFFSET</a>, <a href="sql-error-conditions.html#incorrect_ramp_up_rate">INCORRECT_RAMP_UP_RATE</a>, <a href="invalid-array-index-error-class.md">INVALID_ARRAY_INDEX</a>, <a href="invalid-array-index-in-element-at-error-class.md">INVALID_ARRAY_INDEX_IN_ELEMENT_AT</a>, <a href="sql-error-conditions.html#numeric_out_of_supported_range">NUMERIC_OUT_OF_SUPPORTED_RANGE</a>, <a href="sql-error-conditions.html#numeric_value_out_of_range">NUMERIC_VALUE_OUT_OF_RANGE</a>
</td>
</tr>
<tr>
Expand Down
12 changes: 6 additions & 6 deletions docs/sql-error-conditions.md
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,6 @@ Duplicate map key `<key>` was found, please check the input data. If you want to

Found duplicate keys `<keyColumn>`.

### ELEMENT_AT_BY_INDEX_ZERO

[SQLSTATE: 22003](sql-error-conditions-sqlstates.html#class-22-data-exception)

The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).

### EMPTY_JSON_FIELD_VALUE

[SQLSTATE: 42604](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
Expand Down Expand Up @@ -573,6 +567,12 @@ The fraction of sec must be zero. Valid range is [0, 60]. If necessary set `<ans

The identifier `<ident>` is invalid. Please, consider quoting it with back-quotes as ``<ident>``.

### INVALID_INDEX_OF_ZERO

[SQLSTATE: 22003](sql-error-conditions-sqlstates.html#class-22-data-exception)

The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).

### INVALID_JSON_ROOT_FIELD

[SQLSTATE: 22032](sql-error-conditions-sqlstates.html#class-22-data-exception)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2489,7 +2489,7 @@ case class ElementAt(
}
} else {
val idx = if (index == 0) {
throw QueryExecutionErrors.elementAtByIndexZeroError(getContextOrNull())
throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
} else if (index > 0) {
index - 1
} else {
Expand Down Expand Up @@ -2544,7 +2544,7 @@ case class ElementAt(
| $indexOutOfBoundBranch
|} else {
| if ($index == 0) {
| throw QueryExecutionErrors.elementAtByIndexZeroError($errorContext);
| throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
| } else if ($index > 0) {
| $index--;
| } else {
Expand Down Expand Up @@ -4767,7 +4767,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
since = "3.4.0")
case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes with ComplexTypeMergingExpression
with QueryErrorsBase {
with QueryErrorsBase with SupportQueryContext {

override def inputTypes: Seq[AbstractDataType] = {
(srcArrayExpr.dataType, posExpr.dataType, itemExpr.dataType) match {
Expand Down Expand Up @@ -4820,8 +4820,11 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
}

override def nullSafeEval(arr: Any, pos: Any, item: Any): Any = {
val baseArr = arr.asInstanceOf[ArrayData]
var posInt = pos.asInstanceOf[Int]
if (posInt == 0) {
throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
}
val baseArr = arr.asInstanceOf[ArrayData]
val arrayElementType = dataType.asInstanceOf[ArrayType].elementType

val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements())
Expand Down Expand Up @@ -4895,13 +4898,18 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
values, elementType, resLength, s"$prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(values, elementType, arr,
adjustedAllocIdx, i, first.dataType.asInstanceOf[ArrayType].containsNull)
val errorContext = getContextOrNullCode(ctx)

s"""
|int $itemInsertionIndex = 0;
|int $resLength = 0;
|int $adjustedAllocIdx = 0;
|boolean $insertedItemIsNull = ${itemExpr.isNull};
|
|if ($pos == 0) {
| throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
|}
|
|if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
|
| $resLength = java.lang.Math.abs($pos) + 1;
Expand Down Expand Up @@ -5002,6 +5010,8 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
override protected def withNewChildrenInternal(
newSrcArrayExpr: Expression, newPosExpr: Expression, newItemExpr: Expression): ArrayInsert =
copy(srcArrayExpr = newSrcArrayExpr, posExpr = newPosExpr, itemExpr = newItemExpr)

override def initQueryContext(): Option[SQLQueryContext] = Some(origin.context)
}

@ExpressionDescription(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1605,9 +1605,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
"prettyName" -> prettyName))
}

def elementAtByIndexZeroError(context: SQLQueryContext): RuntimeException = {
def invalidIndexOfZeroError(context: SQLQueryContext): RuntimeException = {
new SparkRuntimeException(
errorClass = "ELEMENT_AT_BY_INDEX_ZERO",
errorClass = "INVALID_INDEX_OF_ZERO",
cause = null,
messageParameters = Map.empty,
context = getQueryContext(context),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2356,7 +2356,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper

// index edge cases
checkEvaluation(ArrayInsert(a1, Literal(2), Literal(3)), Seq(1, 3, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(0), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(1), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(4), Literal(3)), Seq(1, 2, 4, 3))
checkEvaluation(ArrayInsert(a1, Literal(-2), Literal(3)), Seq(1, 3, 2, 4))
Expand Down
17 changes: 14 additions & 3 deletions sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ struct<>
-- !query output
org.apache.spark.SparkRuntimeException
{
"errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
"errorClass" : "INVALID_INDEX_OF_ZERO",
"sqlState" : "22003",
"queryContext" : [ {
"objectType" : "",
Expand Down Expand Up @@ -561,9 +561,20 @@ struct<array_insert(array(1, 2, 3), 3, 4):array<int>>
-- !query
select array_insert(array(2, 3, 4), 0, 1)
-- !query schema
struct<array_insert(array(2, 3, 4), 0, 1):array<int>>
struct<>
-- !query output
[1,2,3,4]
org.apache.spark.SparkRuntimeException
{
"errorClass" : "INVALID_INDEX_OF_ZERO",
"sqlState" : "22003",
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 41,
"fragment" : "array_insert(array(2, 3, 4), 0, 1)"
} ]
}


-- !query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ struct<>
-- !query output
org.apache.spark.SparkRuntimeException
{
"errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
"errorClass" : "INVALID_INDEX_OF_ZERO",
"sqlState" : "22003"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ struct<>
-- !query output
org.apache.spark.SparkRuntimeException
{
"errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
"errorClass" : "INVALID_INDEX_OF_ZERO",
"sqlState" : "22003"
}

Expand Down
17 changes: 14 additions & 3 deletions sql/core/src/test/resources/sql-tests/results/array.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ struct<>
-- !query output
org.apache.spark.SparkRuntimeException
{
"errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
"errorClass" : "INVALID_INDEX_OF_ZERO",
"sqlState" : "22003"
}

Expand Down Expand Up @@ -442,9 +442,20 @@ struct<array_insert(array(1, 2, 3), 3, 4):array<int>>
-- !query
select array_insert(array(2, 3, 4), 0, 1)
-- !query schema
struct<array_insert(array(2, 3, 4), 0, 1):array<int>>
struct<>
-- !query output
[1,2,3,4]
org.apache.spark.SparkRuntimeException
{
"errorClass" : "INVALID_INDEX_OF_ZERO",
"sqlState" : "22003",
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 41,
"fragment" : "array_insert(array(2, 3, 4), 0, 1)"
} ]
}


-- !query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ struct<>
-- !query output
org.apache.spark.SparkRuntimeException
{
"errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
"errorClass" : "INVALID_INDEX_OF_ZERO",
"sqlState" : "22003"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ struct<>
-- !query output
org.apache.spark.SparkRuntimeException
{
"errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
"errorClass" : "INVALID_INDEX_OF_ZERO",
"sqlState" : "22003"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3191,17 +3191,27 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
Seq(Row(Seq[Double](3.0, 3.0, 2.0, 5.0, 1.0, 2.0)))
)
checkAnswer(df4.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq(true, false, false))))
checkAnswer(df5.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq("d", "a", "b", "c"))))

val e1 = intercept[SparkException] {
df5.selectExpr("array_insert(a, b, c)").show()
}
assert(e1.getCause.isInstanceOf[SparkRuntimeException])
checkError(
exception = e1.getCause.asInstanceOf[SparkRuntimeException],
errorClass = "INVALID_INDEX_OF_ZERO",
parameters = Map.empty,
context = ExpectedContext(
fragment = "array_insert(a, b, c)",
start = 0,
stop = 20)
)

checkAnswer(df5.select(
array_insert(col("a"), lit(1), col("c"))),
Seq(Row(Seq("d", "a", "b", "c")))
)
// null checks
checkAnswer(df6.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq("a", null, "b", "c", "d"))))
checkAnswer(df5.select(
array_insert(col("a"), col("b"), lit(null).cast("string"))),
Seq(Row(Seq(null, "a", "b", "c")))
)
checkAnswer(df6.select(
array_insert(col("a"), col("b"), lit(null).cast("string"))),
Seq(Row(Seq("a", null, "b", "c", null)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest
stop = 41))
}

test("ELEMENT_AT_BY_INDEX_ZERO: element_at from array by index zero") {
test("INVALID_INDEX_OF_ZERO: element_at from array by index zero") {
checkError(
exception = intercept[SparkRuntimeException](
sql("select element_at(array(1, 2, 3, 4, 5), 0)").collect()
),
errorClass = "ELEMENT_AT_BY_INDEX_ZERO",
errorClass = "INVALID_INDEX_OF_ZERO",
parameters = Map.empty,
context = ExpectedContext(
fragment = "element_at(array(1, 2, 3, 4, 5), 0)",
Expand Down

0 comments on commit 3e9574c

Please sign in to comment.