Skip to content

Commit

Permalink
[SPARK-48416][SQL] Support nested correlated With expression
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

The inner `With` may reference common expressions of an outer `With`. This PR supports this case by making the rule `RewriteWithExpression` only rewrite top-level `With` expressions, and run the rule repeatedly so that the inner `With` expression becomes top-level `With` after one iteration, and gets rewritten in the next iteration.

### Why are the changes needed?

To support optimized filter pushdown with `With` expression.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

updated the unit test

### Was this patch authored or co-authored using generative AI tooling?

no

Closes apache#49093 from cloud-fan/with.

Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
cloud-fan and cloud-fan committed Dec 12, 2024
1 parent f979bc8 commit df08177
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
// We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
// may produce `With` expressions that need to be rewritten.
Batch("Rewrite With expression", Once, RewriteWithExpression) ::
Batch("Rewrite With expression", fixedPoint, RewriteWithExpression) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,19 @@ object RewriteWithExpression extends Rule[LogicalPlan] {

private def rewriteWithExprAndInputPlans(
e: Expression,
inputPlans: Array[LogicalPlan]): Expression = {
inputPlans: Array[LogicalPlan],
isNestedWith: Boolean = false): Expression = {
if (!e.containsPattern(WITH_EXPRESSION)) return e
e match {
case w: With =>
// Do not handle nested With in one pass. Leave it to the next rule executor batch.
case w: With if !isNestedWith =>
// Rewrite nested With expressions first
val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
val child = rewriteWithExprAndInputPlans(w.child, inputPlans, isNestedWith = true)
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith = true))
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (child.containsPattern(COMMON_EXPR_REF)) {
throw SparkException.internalError(
"Common expression definition cannot reference other Common expression definitions")
}
if (id.canonicalized) {
throw SparkException.internalError(
"Cannot rewrite canonicalized Common expression definitions")
Expand Down Expand Up @@ -148,10 +146,9 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}

child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case ref: CommonExpressionRef =>
if (!refToExpr.contains(ref.id)) {
throw SparkException.internalError("Undefined common expression id " + ref.id)
}
// `child` may contain nested With and we only replace `CommonExpressionRef` that
// references common expressions in the current `With`.
case ref: CommonExpressionRef if refToExpr.contains(ref.id) =>
if (ref.id.canonicalized) {
throw SparkException.internalError(
"Cannot rewrite canonicalized Common expression references")
Expand All @@ -161,7 +158,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {

case c: ConditionalExpression =>
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
rewriteWithExprAndInputPlans(_, inputPlans))
rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith))
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
// Use transformUp to handle nested With.
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
Expand All @@ -174,7 +171,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}
}

case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans))
case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.TempResolvedColumn
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
Expand All @@ -29,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
class RewriteWithExpressionSuite extends PlanTest {

object Optimizer extends RuleExecutor[LogicalPlan] {
val batches = Batch("Rewrite With expression", Once,
val batches = Batch("Rewrite With expression", FixedPoint(5),
PullOutGroupingExpressions,
RewriteWithExpression) :: Nil
}
Expand Down Expand Up @@ -84,13 +83,11 @@ class RewriteWithExpressionSuite extends PlanTest {
ref * ref
}

val plan = testRelation.select(outerExpr.as("col"))
comparePlans(
Optimizer.execute(plan),
Optimizer.execute(testRelation.select(outerExpr.as("col"))),
testRelation
.select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
.select((testRelation.output ++ Seq($"_common_expr_0",
($"_common_expr_0" + $"_common_expr_0" + b).as("_common_expr_1"))): _*)
.select(star(), (a + a).as("_common_expr_0"))
.select(a, b, ($"_common_expr_0" + $"_common_expr_0" + b).as("_common_expr_1"))
.select(($"_common_expr_1" * $"_common_expr_1").as("col"))
.analyze
)
Expand All @@ -104,42 +101,60 @@ class RewriteWithExpressionSuite extends PlanTest {
val outerExpr = With(b + b) { case Seq(ref) =>
ref * ref + innerExpr
}

val plan = testRelation.select(outerExpr.as("col"))
val rewrittenInnerExpr = (a + a).as("_common_expr_0")
val rewrittenOuterExpr = (b + b).as("_common_expr_1")
val finalExpr = rewrittenOuterExpr.toAttribute * rewrittenOuterExpr.toAttribute +
(rewrittenInnerExpr.toAttribute + rewrittenInnerExpr.toAttribute)
val finalExpr = $"_common_expr_1" * $"_common_expr_1" + ($"_common_expr_0" + $"_common_expr_0")
comparePlans(
Optimizer.execute(plan),
Optimizer.execute(testRelation.select(outerExpr.as("col"))),
testRelation
.select((testRelation.output :+ rewrittenInnerExpr): _*)
.select((testRelation.output :+ rewrittenInnerExpr.toAttribute :+ rewrittenOuterExpr): _*)
.select(star(), (b + b).as("_common_expr_1"))
.select(star(), (a + a).as("_common_expr_0"))
.select(finalExpr.as("col"))
.analyze
)
}

test("correlated nested WITH expression is not supported") {
test("correlated nested WITH expression is supported") {
val Seq(a, b) = testRelation.output
val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0))
val outerRef = new CommonExpressionRef(outerCommonExprDef)
val rewrittenOuterExpr = (b + b).as("_common_expr_0")

// The inner expression definition references the outer expression
val commonExprDef1 = CommonExpressionDef(a + a + outerRef, CommonExpressionId(1))
val ref1 = new CommonExpressionRef(commonExprDef1)
val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1))

val outerExpr1 = With(outerRef + innerExpr1, Seq(outerCommonExprDef))
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr1.as("col"))))
comparePlans(
Optimizer.execute(testRelation.select(outerExpr1.as("col"))),
testRelation
// The first Project contains the common expression of the outer With
.select(star(), rewrittenOuterExpr)
// The second Project contains the common expression of the inner With, which references
// the common expression of the outer With.
.select(star(), (a + a + $"_common_expr_0").as("_common_expr_1"))
// The final Project contains the final result expression, which references both common
// expressions.
.select(($"_common_expr_0" + ($"_common_expr_1" + $"_common_expr_1")).as("col"))
.analyze
)

val commonExprDef2 = CommonExpressionDef(a + a)
val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
val ref2 = new CommonExpressionRef(commonExprDef2)
// The inner main expression references the outer expression
val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef1))

val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr2.as("col"))))
comparePlans(
Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
testRelation
// The first Project contains the common expression of the outer With
.select(star(), rewrittenOuterExpr)
// The second Project contains the common expression of the inner With, which does not
// reference the common expression of the outer With.
.select(star(), (a + a).as("_common_expr_2"))
// The final Project contains the final result expression, which references both common
// expressions.
.select(($"_common_expr_0" + ($"_common_expr_2" + $"_common_expr_0")).as("col"))
.analyze
)
}

test("WITH expression in filter") {
Expand Down

0 comments on commit df08177

Please sign in to comment.