Skip to content

Commit

Permalink
[SPARK-38868][SQL] Don't propagate exceptions from filter predicate when optimizing outer joins
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

Change `EliminateOuterJoin#canFilterOutNull` to return `false` when a `where` condition throws an exception.

### Why are the changes needed?

Consider this query:
```
select *
from (select id, id as b from range(0, 10)) l
left outer join (select id, id + 1 as c from range(0, 10)) r
on l.id = r.id
where assert_true(c > 0) is null;
```
The query should succeed, but instead fails with
```
java.lang.RuntimeException: '(c#1L > cast(0 as bigint))' is not true!
```
This happens even though there is no row where `c > 0` is false.

The `EliminateOuterJoin` rule checks if it can convert the outer join to an inner join based on the expression in the where clause, which in this case is
```
assert_true(c > 0) is null
```
`EliminateOuterJoin#canFilterOutNull` evaluates that expression with `c` set to `null` to see if the result is `null` or `false`. That rule doesn't expect the result to be a `RuntimeException`, but in this case it always is.

That is, the assertion is failing during optimization, not at run time.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit test.

Closes apache#36230 from bersprockets/outer_join_eval_assert_issue.

Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
bersprockets authored and cloud-fan committed Apr 25, 2022
1 parent 5046b8c commit e2930b8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.annotation.tailrec
import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
Expand Down Expand Up @@ -151,8 +152,17 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
val emptyRow = new GenericInternalRow(attributes.length)
val boundE = BindReferences.bindReference(e, attributes)
if (boundE.exists(_.isInstanceOf[Unevaluable])) return false
val v = boundE.eval(emptyRow)
v == null || v == false

// some expressions, like map(), may throw an exception when dealing with null values.
// therefore, we need to handle exceptions.
try {
val v = boundE.eval(emptyRow)
v == null || v == false
} catch {
case NonFatal(e) =>
// cannot filter out null if `where` expression throws an exception with null input
false
}
}

private def buildNewJoinType(filter: Filter, join: Join): JoinType = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull}
import org.apache.spark.sql.catalyst.expressions.{Coalesce, If, IsNotNull, Literal, RaiseError}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

class OuterJoinEliminationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
Expand Down Expand Up @@ -252,4 +254,18 @@ class OuterJoinEliminationSuite extends PlanTest {
comparePlans(optimized, originalQuery.analyze)
}
}

test("SPARK-38868: exception thrown from filter predicate does not propagate") {
  val leftSide = testRelation.subquery(Symbol("x"))
  val rightSide = testRelation1.subquery(Symbol("y"))

  // Filter predicate containing a RaiseError branch; the optimizer must not let an
  // exception from it escape while deciding whether the outer join can be eliminated.
  val errorText = Literal(UTF8String.fromString("Bad value"), StringType)
  val raisingPredicate = If("y.d".attr > 0, true, RaiseError(errorText)).isNull

  val joinCondition = Option("x.a".attr === "y.d".attr)
  val originalQuery = leftSide
    .join(rightSide, LeftOuter, joinCondition)
    .where(raisingPredicate)

  // The plan is expected to come out of the optimizer unchanged (still a left outer join).
  val actualPlan = Optimize.execute(originalQuery.analyze)
  val expectedPlan = originalQuery.analyze
  comparePlans(actualPlan, expectedPlan)
}
}

0 comments on commit e2930b8

Please sign in to comment.