Skip to content

Commit

Permalink
[SPARK-20392][SQL] Set barrier to prevent re-entering a tree
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

It is reported that there is performance downgrade when applying ML pipeline for dataset with many columns but few rows.

A big part of the performance downgrade comes from some operations (e.g., `select`) on DataFrame/Dataset which re-create new DataFrame/Dataset with a new `LogicalPlan`. The cost can be ignored in the usage of SQL, normally.

However, it's not rare to chain dozens of pipeline stages in ML. When the query plan grows incrementally during running those stages, the total cost spent on re-creation of DataFrame grows too. In particular, the `Analyzer` will go through the big query plan even most part of it is analyzed.

By eliminating part of the cost, the time to run the example code locally is reduced from about 1min to about 30 secs.

In particular, the time applying the pipeline locally is mostly spent on calling transform of the 137 `Bucketizer`s. Before the change, each call of `Bucketizer`'s transform can cost about 0.4 sec. So the total time spent on all `Bucketizer`s' transform is about 50 secs. After the change, each call only costs about 0.1 sec.

<del>We also make `boundEnc` as lazy variable to reduce unnecessary running time.</del>

### Performance improvement

The codes and datasets provided by Barry Becker to re-produce this issue and benchmark can be found on the JIRA.

Before this patch: about 1 min
After this patch: about 20 secs

## How was this patch tested?

Existing tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <[email protected]>

Closes apache#17770 from viirya/SPARK-20392.
  • Loading branch information
viirya authored and cloud-fan committed May 26, 2017
1 parent f47700c commit 8ce0d8f
Show file tree
Hide file tree
Showing 16 changed files with 151 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,15 @@ class Analyzer(
Batch("Subquery", Once,
UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
CleanupAliases)
CleanupAliases,
EliminateBarriers)
)

/**
* Analyze cte definitions and substitute child plan with analyzed cte definitions.
*/
object CTESubstitution extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case With(child, relations) =>
substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
case (resolved, (name, relation)) =>
Expand Down Expand Up @@ -201,7 +202,7 @@ class Analyzer(
* Substitute child plan with WindowSpecDefinitions.
*/
object WindowsSubstitution extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
case WithWindowDefinition(windowDefinitions, child) =>
child.transform {
Expand Down Expand Up @@ -243,7 +244,7 @@ class Analyzer(
private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child)

Expand Down Expand Up @@ -615,7 +616,7 @@ class Analyzer(
case _ => plan
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
case v: View =>
Expand Down Expand Up @@ -670,7 +671,9 @@ class Analyzer(
* Generate a new logical plan for the right child with different expression IDs
* for all conflicting attributes.
*/
private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
private def dedupRight (left: LogicalPlan, oriRight: LogicalPlan): LogicalPlan = {
// Remove analysis barrier if any.
val right = EliminateBarriers(oriRight)
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
s"between $left and $right")
Expand Down Expand Up @@ -713,7 +716,7 @@ class Analyzer(
* that this rule cannot handle. When that is the case, there must be another rule
* that resolves these conflicts. Otherwise, the analysis will fail.
*/
right
oriRight
case Some((oldRelation, newRelation)) =>
val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
val newRight = right transformUp {
Expand All @@ -726,7 +729,7 @@ class Analyzer(
s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
}
}
newRight
AnalysisBarrier(newRight)
}
}

Expand Down Expand Up @@ -787,7 +790,7 @@ class Analyzer(
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p: LogicalPlan if !p.childrenResolved => p

// If the projection list contains Stars, expand it.
Expand Down Expand Up @@ -961,7 +964,7 @@ class Analyzer(
* have no effect on the results.
*/
object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
Expand Down Expand Up @@ -1017,7 +1020,7 @@ class Analyzer(
}}
}

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case agg @ Aggregate(groups, aggs, child)
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
groups.exists(!_.resolved) =>
Expand All @@ -1041,11 +1044,13 @@ class Analyzer(
* The HAVING clause could also used a grouping columns that is not presented in the SELECT.
*/
object ResolveMissingReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
case sa @ Sort(_, _, child: Aggregate) => sa

case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
case s @ Sort(order, _, orgChild) if !s.resolved && orgChild.resolved =>
val child = EliminateBarriers(orgChild)
try {
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
Expand All @@ -1066,7 +1071,8 @@ class Analyzer(
case ae: AnalysisException => s
}

case f @ Filter(cond, child) if !f.resolved && child.resolved =>
case f @ Filter(cond, orgChild) if !f.resolved && orgChild.resolved =>
val child = EliminateBarriers(orgChild)
try {
val newCond = resolveExpressionRecursively(cond, child)
val requiredAttrs = newCond.references.filter(_.resolved)
Expand All @@ -1093,7 +1099,7 @@ class Analyzer(
*/
private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
if (missingAttrs.isEmpty) {
return plan
return AnalysisBarrier(plan)
}
plan match {
case p: Project =>
Expand Down Expand Up @@ -1165,7 +1171,7 @@ class Analyzer(
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
*/
object ResolveFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
Expand Down Expand Up @@ -1504,7 +1510,7 @@ class Analyzer(
/**
* Resolve and rewrite all subqueries in an operator tree..
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
// its child for resolution.
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
Expand All @@ -1519,7 +1525,7 @@ class Analyzer(
* Turns projections that contain aggregate expressions into aggregations.
*/
object GlobalAggregates extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case Project(projectList, child) if containsAggregates(projectList) =>
Aggregate(Nil, projectList, child)
}
Expand All @@ -1545,7 +1551,9 @@ class Analyzer(
* underlying aggregate operator and then projected away after the original operator.
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) =>
apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier)
case filter @ Filter(havingCondition,
aggregate @ Aggregate(grouping, originalAggExprs, child))
if aggregate.resolved =>
Expand Down Expand Up @@ -1605,6 +1613,8 @@ class Analyzer(
case ae: AnalysisException => filter
}

case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) =>
apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier)
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>

// Try resolving the ordering as though it is in the aggregate clause.
Expand Down Expand Up @@ -1717,7 +1727,7 @@ class Analyzer(
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw new AnalysisException("Generators are not supported when it's nested in " +
Expand Down Expand Up @@ -1775,7 +1785,7 @@ class Analyzer(
* that wrap the [[Generator]].
*/
object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case g: Generate if !g.child.resolved || !g.generator.resolved => g
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
Expand Down Expand Up @@ -2092,7 +2102,7 @@ class Analyzer(
* put them into an inner Project and finally project them away at the outer Project.
*/
object PullOutNondeterministic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.resolved => p // Skip unresolved nodes.
case p: Project => p
case f: Filter => f
Expand Down Expand Up @@ -2137,7 +2147,7 @@ class Analyzer(
* and we should return null if the input is null.
*/
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.resolved => p // Skip unresolved nodes.

case p => p transformExpressionsUp {
Expand Down Expand Up @@ -2202,7 +2212,7 @@ class Analyzer(
* Then apply a Project on a normal Join to eliminate natural or using join.
*/
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
if left.resolved && right.resolved && j.duplicateResolved =>
commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
Expand Down Expand Up @@ -2267,7 +2277,7 @@ class Analyzer(
* to the given input attributes.
*/
object ResolveDeserializer extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.childrenResolved => p
case p if p.resolved => p

Expand Down Expand Up @@ -2353,7 +2363,7 @@ class Analyzer(
* constructed is an inner class.
*/
object ResolveNewInstance extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.childrenResolved => p
case p if p.resolved => p

Expand Down Expand Up @@ -2387,7 +2397,7 @@ class Analyzer(
"type of the field in the target object")
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p if !p.childrenResolved => p
case p if p.resolved => p

Expand Down Expand Up @@ -2441,7 +2451,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
case other => trimAliases(other)
}

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case Project(projectList, child) =>
val cleanedProjectList =
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
Expand Down Expand Up @@ -2470,6 +2480,13 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
}

/** Remove the barrier nodes of analysis */
object EliminateBarriers extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case AnalysisBarrier(child) => child
}
}

/**
* Ignore event time watermark in batch query, which is only supported in Structured Streaming.
* TODO: add this rule into analyzer rule list.
Expand Down Expand Up @@ -2519,7 +2536,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
* @return the logical plan that will generate the time windows using the Expand operator, with
* the Filter operator for correctness and Project for usability.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case p: LogicalPlan if p.children.size == 1 =>
val child = p.children.head
val windowExpressions =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] {
PromotePrecision(Cast(e, dataType))
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
// fix decimal precision for expressions
case q => q.transformExpressions(
decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
})
)

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
case Some(tvf) =>
Expand Down
Loading

0 comments on commit 8ce0d8f

Please sign in to comment.