[SPARK-35077][SQL] Migrate to transformWithPruning for leftover optimizer rules

### What changes were proposed in this pull request?

Migrate to transformWithPruning for the following optimizer rules (a schematic before/after of the migration pattern follows the list):
- SimplifyExtractValueOps
- NormalizeFloatingNumbers
- PushProjectionThroughUnion
- PushDownPredicates
- ExtractPythonUDFFromAggregate
- ExtractPythonUDFFromJoinCondition
- ExtractGroupingPythonUDFFromAggregate
- ExtractPythonUDFs
- CleanupDynamicPruningFilters
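
Each rule follows the same migration pattern: the unconditional `transform` / `transformUp` call is replaced with `transformWithPruning` / `transformUpWithPruning`, guarded by a tree-pattern predicate (plus a `ruleId` where the rule is idempotent). A schematic before/after for `SimplifyExtractValueOps`, condensed from the diff further down (a sketch, not a standalone compilable snippet):

```scala
// Before: the rule visits every node of every plan, even when no
// ExtractValue expression exists anywhere in the tree.
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
  case p => p.transformExpressionsUp {
    case e => e // actual rewrite cases elided
  }
}

// After: the traversal is skipped unless the plan's cached pattern bits say
// an EXTRACT_VALUE node may be present; ruleId additionally lets subtrees
// already processed by this rule be skipped on later runs.
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
  _.containsPattern(EXTRACT_VALUE), ruleId) {
  case p => p.transformExpressionsUpWithPruning(
    _.containsPattern(EXTRACT_VALUE), ruleId) {
    case e => e // actual rewrite cases elided
  }
}
```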


### Why are the changes needed?

Reduce the number of tree traversals and hence improve the query compilation latency.
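
The predicate passed to `transformWithPruning` is cheap because each node caches a bitset of the tree patterns occurring in its subtree, so a rule whose patterns are absent returns without touching the tree at all. A minimal, self-contained sketch of that idea (toy classes and pattern IDs, not Spark's actual implementation):

```scala
import scala.collection.immutable.BitSet

// Toy model of pattern-based pruning: each node caches the union of the
// patterns found in its subtree, so rule applicability is one bitset lookup
// rather than a full tree traversal.
object PruningSketch {
  object Pattern { val FILTER = 0; val JOIN = 1; val PYTHON_UDF = 2 }

  sealed trait Node {
    def pattern: Int
    def children: Seq[Node]
    // Computed once per node and reused by every rule that asks.
    lazy val patternBits: BitSet =
      children.foldLeft(BitSet(pattern))((bits, child) => bits | child.patternBits)
    def containsPattern(p: Int): Boolean = patternBits(p)
  }
  final case class Leaf(pattern: Int) extends Node { val children: Seq[Node] = Nil }
  final case class Branch(pattern: Int, children: Node*) extends Node

  def main(args: Array[String]): Unit = {
    import Pattern._
    val plan = Branch(JOIN, Leaf(FILTER), Leaf(FILTER))
    // An ExtractPythonUDFs-style rule bails out in O(1): no PYTHON_UDF anywhere.
    println(plan.containsPattern(PYTHON_UDF)) // false -> skip traversal
    // A PushDownPredicates-style rule sees its patterns and runs the rewrite.
    println(plan.containsPattern(FILTER) && plan.containsPattern(JOIN)) // true
  }
}
```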

### How was this patch tested?

Existing tests.
Performance diff:
Rule | Baseline | Experiment | Experiment/Baseline
-- | -- | -- | --
SimplifyExtractValueOps | 99367049 | 3679579 | 0.04
NormalizeFloatingNumbers | 24717928 | 20451094 | 0.83
PushProjectionThroughUnion | 14130245 | 7913551 | 0.56
PushDownPredicates | 276333542 | 261246842 | 0.95
ExtractPythonUDFFromAggregate | 6459451 | 2683556 | 0.42
ExtractPythonUDFFromJoinCondition | 5695404 | 2504573 | 0.44
ExtractGroupingPythonUDFFromAggregate | 5546701 | 1858755 | 0.34
ExtractPythonUDFs | 58726458 | 1598518 | 0.03
CleanupDynamicPruningFilters | 26606652 | 15417936 | 0.58
OptimizeSubqueries | 3072287940 | 2876462708 | 0.94


Closes apache#32721 from sigmod/pushdown.

Authored-by: Yingyi Bu <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
sigmod authored and gengliangwang committed Jun 2, 2021
1 parent c2de0a6 commit 3f6322f
Showing 7 changed files with 32 additions and 12 deletions.
@@ -20,13 +20,15 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern._

/**
* Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
*/
object SimplifyExtractValueOps extends Rule[LogicalPlan] {
-override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-case p => p.transformExpressionsUp {
+override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+_.containsPattern(EXTRACT_VALUE), ruleId) {
+case p => p.transformExpressionsUpWithPruning(_.containsPattern(EXTRACT_VALUE), ruleId) {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
createNamedStruct.valExprs(ordinal)
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.types._

/**
@@ -56,7 +57,7 @@ import org.apache.spark.sql.types._
object NormalizeFloatingNumbers extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan match {
-case _ => plan transform {
+case _ => plan.transformWithPruning( _.containsAnyPattern(WINDOW, JOIN)) {
case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
// Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
// to normalize the `windowExpressions`, as they are executed per input row and should take
@@ -277,6 +277,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
*/
object OptimizeSubqueries extends Rule[LogicalPlan] {
private def removeTopLevelSort(plan: LogicalPlan): LogicalPlan = {
+if (!plan.containsPattern(SORT)) {
+return plan
+}
plan match {
case Sort(_, _, child) => child
case Project(fields, child) => Project(fields, removeTopLevelSort(child))
@@ -683,7 +686,8 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
result.asInstanceOf[A]
}

-def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+_.containsAllPatterns(UNION, PROJECT)) {

// Push down deterministic projection through UNION ALL
case p @ Project(projectList, u: Union) =>
@@ -1283,7 +1287,8 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
* Filter-Join-Join-Join. Most predicates can be pushed down in a single pass.
*/
object PushDownPredicates extends Rule[LogicalPlan] with PredicateHelper {
-def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+_.containsAnyPattern(FILTER, JOIN)) {
CombineFilters.applyLocally
.orElse(PushPredicateThroughNonJoin.applyLocally)
.orElse(PushPredicateThroughJoin.applyLocally)
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{INNER_LIKE_JOIN, OUTER_JOIN}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf

@@ -203,7 +203,8 @@ object ExtractPythonUDFFromJoinCondition extends Rule[LogicalPlan] with Predicat
}.isDefined
}

-override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+_.containsAllPatterns(PYTHON_UDF, JOIN)) {
case j @ Join(_, _, joinType, Some(cond), _) if hasUnevaluablePythonUDF(cond, j) =>
if (!joinType.isInstanceOf[InnerLike]) {
// The current strategy supports only InnerLike join because for other types,
@@ -147,6 +147,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyConditionalsInPredicate" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps" ::
"org.apache.spark.sql.catalyst.optimizer.TransposeWindow" ::
"org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison" :: Nil
}
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}

/**
@@ -36,12 +37,15 @@ object CleanupDynamicPruningFilters extends Rule[LogicalPlan] with PredicateHelp
return plan
}

-plan.transform {
+plan.transformWithPruning(
+// No-op for trees that do not contain dynamic pruning.
+_.containsAnyPattern(DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY)) {
// pass through anything that is pushed down into PhysicalOperation
case p @ PhysicalOperation(_, _, LogicalRelation(_: HadoopFsRelation, _, _, _)) => p
// remove any Filters with DynamicPruning that didn't get pushed down to PhysicalOperation.
case f @ Filter(condition, _) =>
-val newCondition = condition.transform {
+val newCondition = condition.transformWithPruning(
+_.containsAnyPattern(DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY)) {
case _: DynamicPruning => TrueLiteral
}
f.copy(condition = newCondition)
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.errors.QueryCompilationErrors


@@ -75,7 +76,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
Project(projList.toSeq, agg.copy(aggregateExpressions = aggExpr.toSeq))
}

-def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+_.containsAllPatterns(PYTHON_UDF, AGGREGATE)) {
case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
extract(agg)
}
@@ -139,7 +141,8 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
child = Project((projList ++ agg.child.output).toSeq, agg.child))
}

-def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+_.containsAllPatterns(PYTHON_UDF, AGGREGATE)) {
case agg: Aggregate if agg.groupingExpressions.exists(hasScalarPythonUDF(_)) =>
extract(agg)
}
@@ -207,7 +210,10 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
// eventually. Here we skip subquery, as Python UDF only needs to be extracted once.
case s: Subquery if s.correlated => plan

-case _ => plan transformUp {
+case _ => plan.transformUpWithPruning(
+// All cases must contain pattern PYTHON_UDF. PythonUDFs are member fields of BatchEvalPython
+// and ArrowEvalPython.
+_.containsPattern(PYTHON_UDF)) {
// A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and
// `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't
// extract Python UDFs from them.
