Skip to content

Commit

Permalink
[SPARK-33736][SQL] Handle MERGE in ReplaceNullWithFalseInPredicate
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR handles merge operations in `ReplaceNullWithFalseInPredicate`.

### Why are the changes needed?

These changes are needed to match what we already do for delete and update operations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR extends existing tests to cover merge operations.

Closes apache#31579 from aokolnychyi/spark-33736.

Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
aokolnychyi authored and dongjoon-hyun committed Feb 18, 2021
1 parent 44a9aed commit 1ad3432
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan, UpdateTable}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -54,6 +54,11 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
case m @ MergeIntoTable(_, _, mergeCond, matchedActions, notMatchedActions) =>
m.copy(
mergeCondition = replaceNullWithFalse(mergeCond),
matchedActions = replaceNullWithFalse(matchedActions),
notMatchedActions = replaceNullWithFalse(notMatchedActions))
case p: LogicalPlan => p transformExpressions {
// For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no
// difference, as `null <=> true` and `false <=> true` both return false.
Expand Down Expand Up @@ -114,4 +119,13 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
e
}
}

/**
 * Applies the expression-level `replaceNullWithFalse` to the condition of every merge action
 * that carries one.
 *
 * @param mergeActions the actions of a MERGE operation (matched or not-matched)
 * @return a sequence of the same length and order; actions with a defined condition are copied
 *         with the rewritten condition, all other actions are returned as they are
 */
private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = {
  // Rewrites a single action; only the condition is touched, any other field is kept via `copy`.
  def rewriteAction(action: MergeAction): MergeAction = action match {
    case update @ UpdateAction(Some(predicate), _) =>
      update.copy(condition = Some(replaceNullWithFalse(predicate)))
    case delete @ DeleteAction(Some(predicate)) =>
      delete.copy(condition = Some(replaceNullWithFalse(predicate)))
    case insert @ InsertAction(Some(predicate), _) =>
      insert.copy(condition = Some(replaceNullWithFalse(predicate)))
    // Anything not matched above (e.g. an action whose condition is None) has nothing to rewrite.
    case untouched =>
      untouched
  }

  mergeActions.map(rewriteAction)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, DeleteFromTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, IntegerType}
Expand All @@ -50,6 +50,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
testUpdate(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
testMerge(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
}

test("Not expected type - replaceNullWithFalse") {
Expand All @@ -68,6 +69,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
}

test("replace nulls in nested expressions in branches of If") {
Expand All @@ -79,6 +81,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
}

test("replace null in elseValue of CaseWhen") {
Expand All @@ -91,6 +94,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond)
testDelete(originalCond, expectedCond)
testUpdate(originalCond, expectedCond)
testMerge(originalCond, expectedCond)
}

test("replace null in branch values of CaseWhen") {
Expand All @@ -102,6 +106,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
}

test("replace null in branches of If inside CaseWhen") {
Expand All @@ -120,6 +125,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond)
testDelete(originalCond, expectedCond)
testUpdate(originalCond, expectedCond)
testMerge(originalCond, expectedCond)
}

test("replace null in complex CaseWhen expressions") {
Expand All @@ -141,6 +147,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond)
testDelete(originalCond, expectedCond)
testUpdate(originalCond, expectedCond)
testMerge(originalCond, expectedCond)
}

test("replace null in Or") {
Expand All @@ -150,6 +157,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond)
testDelete(originalCond, expectedCond)
testUpdate(originalCond, expectedCond)
testMerge(originalCond, expectedCond)
}

test("replace null in And") {
Expand All @@ -158,6 +166,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
}

test("replace nulls in nested And/Or expressions") {
Expand All @@ -168,6 +177,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
}

test("replace null in And inside branches of If") {
Expand All @@ -179,6 +189,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
}

test("replace null in branches of If inside And") {
Expand All @@ -192,6 +203,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
}

test("replace null in branches of If inside another If") {
Expand All @@ -203,6 +215,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
}

test("replace null in CaseWhen inside another CaseWhen") {
Expand All @@ -212,6 +225,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond, expectedCond = FalseLiteral)
testDelete(originalCond, expectedCond = FalseLiteral)
testUpdate(originalCond, expectedCond = FalseLiteral)
testMerge(originalCond, expectedCond = FalseLiteral)
}

test("inability to replace null in non-boolean branches of If") {
Expand All @@ -226,6 +240,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond = condition, expectedCond = condition)
testDelete(originalCond = condition, expectedCond = condition)
testUpdate(originalCond = condition, expectedCond = condition)
testMerge(originalCond = condition, expectedCond = condition)
}

test("inability to replace null in non-boolean values of CaseWhen") {
Expand All @@ -244,6 +259,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond)
testMerge(originalCond = condition, expectedCond = expectedCond)
}

test("inability to replace null in non-boolean branches of If inside another If") {
Expand All @@ -262,6 +278,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond)
testMerge(originalCond = condition, expectedCond = expectedCond)
}

test("replace null in If used as a join condition") {
Expand Down Expand Up @@ -396,11 +413,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(allFalseCond, FalseLiteral)
testDelete(allFalseCond, FalseLiteral)
testUpdate(allFalseCond, FalseLiteral)
testMerge(allFalseCond, FalseLiteral)

testFilter(nonAllFalseCond, nonAllFalseCond)
testJoin(nonAllFalseCond, nonAllFalseCond)
testDelete(nonAllFalseCond, nonAllFalseCond)
testUpdate(nonAllFalseCond, nonAllFalseCond)
testMerge(nonAllFalseCond, nonAllFalseCond)
}

test("replace None of elseValue inside CaseWhen if all branches are null") {
Expand All @@ -412,6 +431,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testJoin(allFalseCond, FalseLiteral)
testDelete(allFalseCond, FalseLiteral)
testUpdate(allFalseCond, FalseLiteral)
testMerge(allFalseCond, FalseLiteral)
}

private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
Expand All @@ -434,6 +454,21 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
test((rel, expr) => UpdateTable(rel, Seq.empty, Some(expr)), originalCond, expectedCond)
}

/**
 * Runs the shared `test` helper against a `MergeIntoTable` plan in which the supplied expression
 * is used as the merge condition and as the condition of one update action, one delete action
 * and one insert action.
 *
 * @param originalCond the condition placed into the plan before the rule under test runs
 * @param expectedCond the condition expected in its place afterwards
 */
private def testMerge(originalCond: Expression, expectedCond: Expression): Unit = {
  // Typed explicitly as a function value so it can be handed straight to `test`.
  val buildMergePlan: (LogicalPlan, Expression) => LogicalPlan = (relation, condition) => {
    // Each column is assigned to itself; the assignments only exist to populate the actions.
    val selfAssignments = Seq(
      Assignment('i, 'i),
      Assignment('b, 'b),
      Assignment('a, 'a),
      Assignment('m, 'm))
    val whenMatched = Seq(
      UpdateAction(Some(condition), selfAssignments),
      DeleteAction(Some(condition)))
    val whenNotMatched = Seq(InsertAction(Some(condition), selfAssignments))
    // The same relation is passed as the first and the second argument, as in the original helper.
    MergeIntoTable(
      relation,
      relation,
      mergeCondition = condition,
      matchedActions = whenMatched,
      notMatchedActions = whenNotMatched)
  }

  test(buildMergePlan, originalCond, expectedCond)
}

private def testHigherOrderFunc(
argument: Expression,
createExpr: (Expression, Expression) => Expression,
Expand Down

0 comments on commit 1ad3432

Please sign in to comment.