[SPARK-39040][SQL] Respect NaNvl in EquivalentExpressions for expression elimination

### What changes were proposed in this pull request?

Respect NaNvl in EquivalentExpressions for expression elimination.

### Why are the changes needed?

For example, the following query fails:
```sql
set spark.sql.ansi.enabled=true;
set spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConstantFolding;
SELECT nanvl(1, 1/0 + 1/0);
```
```
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 4.0 failed 1 times, most recent failure: Lost task 0.0 in stage 4.0 (TID 4) (10.221.98.68 executor driver): org.apache.spark.SparkArithmeticException: divide by zero. To return NULL instead, use 'try_divide'. If necessary set spark.sql.ansi.enabled to false (except for ANSI interval type) to bypass this error.
== SQL(line 1, position 17) ==
select nanvl(1 , 1/0 + 1/0)
                 ^^^
    at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:151)
```
We should respect the evaluation order of conditional expressions, which always evaluate the predicate branch first, so the query above should not fail.
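
The failure mode can be illustrated with a toy evaluator. This is only a sketch and not Spark code: `NaNvlEvalSketch`, `Lit`, `Div`, `Plus`, `NaNvlExpr`, `eval`, and `evalWithHoisting` are hypothetical names introduced here, and `evalWithHoisting` merely assumes that an eliminated subexpression is computed once up front, before the conditional is evaluated.

```scala
// Toy model of ANSI-style evaluation. All names are hypothetical, not Spark's API.
object NaNvlEvalSketch {
  sealed trait Expr
  final case class Lit(value: Double) extends Expr
  final case class Div(left: Expr, right: Expr) extends Expr
  final case class Plus(left: Expr, right: Expr) extends Expr
  final case class NaNvlExpr(left: Expr, right: Expr) extends Expr

  def eval(e: Expr): Double = e match {
    case Lit(v) => v
    case Div(l, r) =>
      val divisor = eval(r)
      // Mimics ANSI mode: dividing by zero raises an error instead of returning NULL.
      if (divisor == 0.0) throw new ArithmeticException("divide by zero")
      eval(l) / divisor
    case Plus(l, r) => eval(l) + eval(r)
    case NaNvlExpr(l, r) =>
      // The left child is always evaluated; the right child only when left is NaN.
      val leftValue = eval(l)
      if (leftValue.isNaN) eval(r) else leftValue
  }

  // Assumption: a hoisted common subexpression is evaluated eagerly, once,
  // before the root expression. Its value is discarded in this sketch.
  def evalWithHoisting(hoisted: Expr, root: Expr): Double = {
    eval(hoisted)
    eval(root)
  }
}
```

With `div = Div(Lit(1.0), Lit(0.0))` and `query = NaNvlExpr(Lit(1.0), Plus(div, div))`, `eval(query)` never touches the right child, while `evalWithHoisting(div, query)` throws because `div` is computed unconditionally.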

### Does this PR introduce _any_ user-facing change?

Yes, this is a bug fix.

### How was this patch tested?

Added tests.

Closes apache#36376 from ulysses-you/SPARK-39040.

Authored-by: ulysses-you <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
ulysses-you authored and cloud-fan committed Apr 29, 2022
1 parent 9b16579 commit f6b43f0
Showing 4 changed files with 51 additions and 0 deletions.
@@ -135,11 +135,14 @@ class EquivalentExpressions {
  //    will always get accessed.
  // 4. Coalesce: it's also a conditional expression, we should only recurse into the first
  //    children, because others may not get accessed.
  // 5. NaNvl: it's a conditional expression, we can only guarantee the left child can be always
  //    accessed. And if we hit the left child, the right will not be accessed.
  private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
    case _: CodegenFallback => Nil
    case i: If => i.predicate :: Nil
    case c: CaseWhen => c.children.head :: Nil
    case c: Coalesce => c.children.head :: Nil
    case n: NaNvl => n.left :: Nil
    case other => other.children
  }
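
The effect of the new `NaNvl` case can be sketched without Spark. The model below is a simplification with hypothetical names (`RecurseSketch`, `Lit`, `Plus`, `NaNvlExpr`, `countAlwaysEvaluated`, `commonSubexpressions`); it only counts expressions reached through always-evaluated children and assumes that an expression seen more than once is a candidate for elimination.

```scala
import scala.collection.mutable

// Simplified stand-in for the traversal above. Names are hypothetical, not Spark's.
object RecurseSketch {
  sealed trait Expr { def children: Seq[Expr] }
  final case class Lit(value: Double) extends Expr {
    def children: Seq[Expr] = Nil
  }
  final case class Plus(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
  }
  final case class NaNvlExpr(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
  }

  // Only the left child of NaNvl is guaranteed to be evaluated.
  def childrenToRecurse(e: Expr): Seq[Expr] = e match {
    case n: NaNvlExpr => n.left :: Nil
    case other => other.children
  }

  // Counts how often each non-leaf expression is reached via always-evaluated paths.
  def countAlwaysEvaluated(root: Expr): Map[Expr, Int] = {
    val counts = mutable.Map.empty[Expr, Int]
    def visit(e: Expr): Unit = if (e.children.nonEmpty) {
      val seen = counts.getOrElse(e, 0) + 1
      counts.update(e, seen)
      // Recurse only the first time an expression is seen.
      if (seen == 1) childrenToRecurse(e).foreach(visit)
    }
    visit(root)
    counts.toMap
  }

  def commonSubexpressions(root: Expr): Set[Expr] =
    countAlwaysEvaluated(root).filter(_._2 > 1).keySet
}
```

In this model, `Plus(add, add)` reports `add` as common, whereas `NaNvlExpr(Lit(1.0), Plus(add, add))` reports nothing, because the traversal never descends into the right child.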

@@ -173,6 +176,7 @@ class EquivalentExpressions {
    // If there is only one child, the first child is already covered by
    // `childrenToRecurse` and we should exclude it here.
    case c: Coalesce if c.children.length > 1 => Seq(c.children)
    case n: NaNvl => Seq(n.children)
    case _ => Nil
  }
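
Returning `Seq(n.children)` treats both children of `NaNvl` as one group of conditional branches, so only subexpressions present in every branch are candidates. A minimal sketch of that intersection, assuming a plain set-based comparison and using hypothetical names (`CommonBranchSketch`, `subexpressions`, `commonToAll`) rather than Spark's internals:

```scala
// Simplified illustration of "common to all branches". Names are hypothetical.
object CommonBranchSketch {
  sealed trait Expr { def children: Seq[Expr] }
  final case class Lit(value: Double) extends Expr {
    def children: Seq[Expr] = Nil
  }
  final case class Plus(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
  }

  // All non-leaf expressions contained in e, including e itself.
  def subexpressions(e: Expr): Set[Expr] =
    if (e.children.isEmpty) Set.empty[Expr]
    else e.children.flatMap(subexpressions).toSet + e

  // Expressions that occur in every branch of a conditional group.
  def commonToAll(branches: Seq[Expr]): Set[Expr] =
    branches.map(subexpressions).reduceOption(_ intersect _).getOrElse(Set.empty[Expr])
}
```

For the two cases exercised by the unit test in this commit, the sketch gives `Set(add)` for the branches `Seq(add, add)` and an empty set for `Seq(Lit(1.0), Plus(add, add))`, since a literal contributes no subexpressions.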

@@ -447,6 +447,20 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
    // So if `p` is replaced by subexpression, the literal will be reused.
    assert(code.value.toString == "((Decimal) references[0] /* literal */)")
  }

  test("SPARK-39040: Respect NaNvl in EquivalentExpressions for expression elimination") {
    val add = Add(Literal(1), Literal(0))
    val n1 = NaNvl(Literal(1.0d), Add(add, add))
    val e1 = new EquivalentExpressions
    e1.addExprTree(n1)
    assert(e1.getCommonSubexpressions.isEmpty)

    val n2 = NaNvl(add, add)
    val e2 = new EquivalentExpressions
    e2.addExprTree(n2)
    assert(e2.getCommonSubexpressions.size == 1)
    assert(e2.getCommonSubexpressions.head == add)
  }
}

case class CodegenFallbackExpression(child: Expression)
@@ -0,0 +1,6 @@
-- Tests for conditional functions
CREATE TABLE t USING PARQUET AS SELECT c1, c2 FROM VALUES(1, 0),(2, 1) AS t(c1, c2);

SELECT nanvl(c1, c1/c2 + c1/c2) FROM t;

DROP TABLE IF EXISTS t;
@@ -0,0 +1,27 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 3


-- !query
CREATE TABLE t USING PARQUET AS SELECT c1, c2 FROM VALUES(1, 0),(2, 1) AS t(c1, c2)
-- !query schema
struct<>
-- !query output



-- !query
SELECT nanvl(c1, c1/c2 + c1/c2) FROM t
-- !query schema
struct<nanvl(c1, ((c1 / c2) + (c1 / c2))):double>
-- !query output
1.0
2.0


-- !query
DROP TABLE IF EXISTS t
-- !query schema
struct<>
-- !query output
