[SPARK-21308][SQL] Remove SQLConf parameters from the optimizer
### What changes were proposed in this pull request?
This PR removes the `SQLConf` constructor parameters from the optimizer rules; each rule becomes a parameterless `object` that reads the active configuration through `SQLConf.get` instead.
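
Every file in the diff applies the same mechanical change: a rule that was a `case class` taking a `SQLConf` becomes a parameterless `object` that reads the active conf when it runs. Below is a self-contained sketch of that shape, not Spark code: `Conf` and `ConfHolder` are hypothetical stand-ins for `SQLConf` and `SQLConf.get`, and the threshold field mirrors the `optimizerInSetConversionThreshold` check in `OptimizeIn`.

```scala
// Hypothetical stand-ins for SQLConf / SQLConf.get, used only to keep
// this sketch compilable on its own; not the real Spark API.
case class Conf(inSetConversionThreshold: Int)

object ConfHolder {
  // SQLConf.get resolves whichever conf is active when a rule executes;
  // a ThreadLocal with a default is the closest standalone analogue.
  private val active = new ThreadLocal[Conf] {
    override def initialValue(): Conf = Conf(inSetConversionThreshold = 10)
  }
  def get: Conf = active.get()
}

// Before: the conf is frozen into the rule at construction time, so a
// distinct rule instance is needed per configuration.
case class OptimizeInBefore(conf: Conf) {
  def shouldConvertToSet(listSize: Int): Boolean =
    listSize > conf.inSetConversionThreshold
}

// After: one shared object re-reads the active conf on every invocation.
object OptimizeInAfter {
  def shouldConvertToSet(listSize: Int): Boolean =
    listSize > ConfHolder.get.inSetConversionThreshold
}
```

The same reasoning is why `Optimizer.fixedPoint` changes from a `val` to a `def`: `optimizerMaxIterations` is re-read each time the batches are built instead of being captured once when the optimizer is instantiated.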

### How was this patch tested?
The existing test cases, with JoinReorderSuite updated to set and restore the CBO and join-reorder flags in beforeAll/afterAll.

Author: gatorsmile <[email protected]>

Closes apache#18533 from gatorsmile/rmSQLConfOptimizer.
gatorsmile authored and cloud-fan committed Jul 6, 2017
commit 75b168f (parent ab866f1)
Showing 23 changed files with 137 additions and 82 deletions.
======================================================================
@@ -32,7 +32,10 @@ import org.apache.spark.sql.internal.SQLConf
  * We may have several join reorder algorithms in the future. This class is the entry of these
  * algorithms, and chooses which one to use.
  */
-case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
+
+  private def conf = SQLConf.get
+
   def apply(plan: LogicalPlan): LogicalPlan = {
     if (!conf.cboEnabled || !conf.joinReorderEnabled) {
       plan
@@ -379,7 +382,7 @@ object JoinReorderDPFilters extends PredicateHelper {
 
     if (conf.joinReorderDPStarFilter) {
       // Compute the tables in a star-schema relationship.
-      val starJoin = StarSchemaDetection(conf).findStarJoins(items, conditions.toSeq)
+      val starJoin = StarSchemaDetection.findStarJoins(items, conditions.toSeq)
       val nonStarJoin = items.filterNot(starJoin.contains(_))
 
       if (starJoin.nonEmpty && nonStarJoin.nonEmpty) {
======================================================================
@@ -34,10 +34,10 @@ import org.apache.spark.sql.types._
  * Abstract class all optimizers should inherit of, contains the standard batches (extending
  * Optimizers can override this.
  */
-abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
+abstract class Optimizer(sessionCatalog: SessionCatalog)
   extends RuleExecutor[LogicalPlan] {
 
-  protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations)
+  protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
 
   def batches: Seq[Batch] = {
     Batch("Eliminate Distinct", Once, EliminateDistinct) ::
@@ -77,11 +77,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
     Batch("Operator Optimizations", fixedPoint, Seq(
       // Operator push down
       PushProjectionThroughUnion,
-      ReorderJoin(conf),
+      ReorderJoin,
       EliminateOuterJoin,
       PushPredicateThroughJoin,
       PushDownPredicate,
-      LimitPushDown(conf),
+      LimitPushDown,
       ColumnPruning,
       InferFiltersFromConstraints,
       // Operator combine
@@ -92,10 +92,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
       CombineLimits,
       CombineUnions,
       // Constant folding and strength reduction
-      NullPropagation(conf),
+      NullPropagation,
       ConstantPropagation,
       FoldablePropagation,
-      OptimizeIn(conf),
+      OptimizeIn,
       ConstantFolding,
       ReorderAssociativeOperator,
       LikeSimplification,
@@ -117,19 +117,19 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
       CombineConcats) ++
       extendedOperatorOptimizationRules: _*) ::
     Batch("Check Cartesian Products", Once,
-      CheckCartesianProducts(conf)) ::
+      CheckCartesianProducts) ::
     Batch("Join Reorder", Once,
-      CostBasedJoinReorder(conf)) ::
+      CostBasedJoinReorder) ::
     Batch("Decimal Optimizations", fixedPoint,
-      DecimalAggregates(conf)) ::
+      DecimalAggregates) ::
     Batch("Object Expressions Optimization", fixedPoint,
       EliminateMapObjects,
       CombineTypedFilters) ::
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation,
       PropagateEmptyRelation) ::
     Batch("OptimizeCodegen", Once,
-      OptimizeCodegen(conf)) ::
+      OptimizeCodegen) ::
     Batch("RewriteSubquery", Once,
       RewritePredicateSubquery,
       CollapseProject) :: Nil
@@ -178,8 +178,7 @@ class SimpleTestOptimizer extends Optimizer(
   new SessionCatalog(
     new InMemoryCatalog,
     EmptyFunctionRegistry,
-    new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)),
-  new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true))
+    new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)))
 
 /**
  * Remove redundant aliases from a query plan. A redundant alias is an alias that does not change
@@ -288,7 +287,7 @@ object RemoveRedundantProject extends Rule[LogicalPlan] {
 /**
  * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins.
  */
-case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
+object LimitPushDown extends Rule[LogicalPlan] {
 
   private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = {
     plan match {
@@ -1077,8 +1076,7 @@ object CombineLimits extends Rule[LogicalPlan] {
  * the join between R and S is not a cartesian product and therefore should be allowed.
  * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule.
  */
-case class CheckCartesianProducts(conf: SQLConf)
-  extends Rule[LogicalPlan] with PredicateHelper {
+object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper {
   /**
    * Check if a join is a cartesian product. Returns true if
    * there are no join conditions involving references from both left and right.
@@ -1090,7 +1088,7 @@ case class CheckCartesianProducts(conf: SQLConf)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan =
-    if (conf.crossJoinEnabled) {
+    if (SQLConf.get.crossJoinEnabled) {
       plan
     } else plan transform {
       case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, condition)
@@ -1112,7 +1110,7 @@ case class CheckCartesianProducts(conf: SQLConf)
  * This uses the same rules for increasing the precision and scale of the output as
  * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]].
  */
-case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] {
+object DecimalAggregates extends Rule[LogicalPlan] {
   import Decimal.MAX_LONG_DIGITS
 
   /** Maximum number of decimal digits representable precisely in a Double */
@@ -1130,7 +1128,7 @@ case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] {
               we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
             Cast(
               Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
-              DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone))
+              DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone))
 
           case _ => we
         }
@@ -1142,7 +1140,7 @@ case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] {
         val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
         Cast(
           Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
-          DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone))
+          DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone))
 
       case _ => ae
     }
======================================================================
@@ -28,7 +28,9 @@ import org.apache.spark.sql.internal.SQLConf
 /**
  * Encapsulates star-schema detection logic.
  */
-case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
+object StarSchemaDetection extends PredicateHelper {
+
+  private def conf = SQLConf.get
 
   /**
    * Star schema consists of one or more fact tables referencing a number of dimension
======================================================================
@@ -173,12 +173,12 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
  * 2. Replaces [[In (value, seq[Literal])]] with optimized version
  *    [[InSet (value, HashSet[Literal])]] which is much faster.
  */
-case class OptimizeIn(conf: SQLConf) extends Rule[LogicalPlan] {
+object OptimizeIn extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsDown {
       case expr @ In(v, list) if expr.inSetConvertible =>
         val newList = ExpressionSet(list).toSeq
-        if (newList.size > conf.optimizerInSetConversionThreshold) {
+        if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) {
           val hSet = newList.map(e => e.eval(EmptyRow))
           InSet(v, HashSet() ++ hSet)
         } else if (newList.size < list.size) {
@@ -414,7 +414,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
  * equivalent [[Literal]] values. This rule is more specific with
  * Null value propagation from bottom to top of the expression tree.
  */
-case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] {
+object NullPropagation extends Rule[LogicalPlan] {
   private def isNullLiteral(e: Expression): Boolean = e match {
     case Literal(null, _) => true
     case _ => false
@@ -423,9 +423,9 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
       case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
-        Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
+        Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone))
       case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) =>
-        Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
+        Cast(Literal(0L), e.dataType, Option(SQLConf.get.sessionLocalTimeZone))
       case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
         // This rule should be only triggered when isDistinct field is false.
         ae.copy(aggregateFunction = Count(Literal(1)))
@@ -552,14 +552,14 @@ object FoldablePropagation extends Rule[LogicalPlan] {
 /**
  * Optimizes expressions by replacing according to CodeGen configuration.
  */
-case class OptimizeCodegen(conf: SQLConf) extends Rule[LogicalPlan] {
+object OptimizeCodegen extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case e: CaseWhen if canCodegen(e) => e.toCodegen()
   }
 
   private def canCodegen(e: CaseWhen): Boolean = {
     val numBranches = e.branches.size + e.elseValue.size
-    numBranches <= conf.maxCaseBranchesForCodegen
+    numBranches <= SQLConf.get.maxCaseBranchesForCodegen
   }
 }
 
======================================================================
@@ -34,7 +34,7 @@ import org.apache.spark.sql.internal.SQLConf
  *
  * If star schema detection is enabled, reorder the star join plans based on heuristics.
  */
-case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
   /**
    * Join a list of plans together and push down the conditions into them.
    *
@@ -87,8 +87,8 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case ExtractFiltersAndInnerJoins(input, conditions)
         if input.size > 2 && conditions.nonEmpty =>
-      if (conf.starSchemaDetection && !conf.cboEnabled) {
-        val starJoinPlan = StarSchemaDetection(conf).reorderStarJoins(input, conditions)
+      if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
+        val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions)
         if (starJoinPlan.nonEmpty) {
           val rest = input.filterNot(starJoinPlan.contains(_))
           createOrderedJoin(starJoinPlan ++ rest, conditions)
======================================================================
@@ -33,7 +33,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
       Batch("AnalysisNodes", Once,
         EliminateSubqueryAliases) ::
       Batch("Constant Folding", FixedPoint(50),
-        NullPropagation(conf),
+        NullPropagation,
         ConstantFolding,
         BooleanSimplification,
         SimplifyBinaryComparison,
======================================================================
@@ -35,7 +35,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
       Batch("AnalysisNodes", Once,
         EliminateSubqueryAliases) ::
       Batch("Constant Folding", FixedPoint(50),
-        NullPropagation(conf),
+        NullPropagation,
         ConstantFolding,
         BooleanSimplification,
         PruneFilters) :: Nil
======================================================================
@@ -32,7 +32,7 @@ class CombiningLimitsSuite extends PlanTest {
       Batch("Combine Limit", FixedPoint(10),
         CombineLimits) ::
       Batch("Constant Folding", FixedPoint(10),
-        NullPropagation(conf),
+        NullPropagation,
         ConstantFolding,
         BooleanSimplification,
         SimplifyConditionals) :: Nil
======================================================================
@@ -33,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest {
       Batch("AnalysisNodes", Once,
         EliminateSubqueryAliases) ::
       Batch("ConstantFolding", Once,
-        OptimizeIn(conf),
+        OptimizeIn,
         ConstantFolding,
         BooleanSimplification) :: Nil
   }
======================================================================
@@ -29,7 +29,7 @@ class DecimalAggregatesSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Decimal Optimizations", FixedPoint(100),
-      DecimalAggregates(conf)) :: Nil
+      DecimalAggregates) :: Nil
   }
 
   val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1))
======================================================================
@@ -31,7 +31,7 @@ class EliminateMapObjectsSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = {
       Batch("EliminateMapObjects", FixedPoint(50),
-        NullPropagation(conf),
+        NullPropagation,
         SimplifyCasts,
         EliminateMapObjects) :: Nil
     }
======================================================================
@@ -37,7 +37,7 @@ class JoinOptimizationSuite extends PlanTest {
         CombineFilters,
         PushDownPredicate,
         BooleanSimplification,
-        ReorderJoin(conf),
+        ReorderJoin,
         PushPredicateThroughJoin,
         ColumnPruning,
         CollapseProject) :: Nil
======================================================================
@@ -24,25 +24,42 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED}
 
 
 class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
 
-  override val conf = new SQLConf().copy(CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true)
-
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("Operator Optimizations", FixedPoint(100),
         CombineFilters,
         PushDownPredicate,
-        ReorderJoin(conf),
+        ReorderJoin,
         PushPredicateThroughJoin,
         ColumnPruning,
         CollapseProject) ::
       Batch("Join Reorder", Once,
-        CostBasedJoinReorder(conf)) :: Nil
+        CostBasedJoinReorder) :: Nil
   }
 
+  var originalConfCBOEnabled = false
+  var originalConfJoinReorderEnabled = false
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    originalConfCBOEnabled = conf.cboEnabled
+    originalConfJoinReorderEnabled = conf.joinReorderEnabled
+    conf.setConf(CBO_ENABLED, true)
+    conf.setConf(JOIN_REORDER_ENABLED, true)
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      conf.setConf(CBO_ENABLED, originalConfCBOEnabled)
+      conf.setConf(JOIN_REORDER_ENABLED, originalConfJoinReorderEnabled)
+    } finally {
+      super.afterAll()
+    }
+  }
+
   /** Set up tables and columns for testing */
======================================================================
@@ -32,7 +32,7 @@ class LimitPushdownSuite extends PlanTest {
       Batch("Subqueries", Once,
         EliminateSubqueryAliases) ::
       Batch("Limit pushdown", FixedPoint(100),
-        LimitPushDown(conf),
+        LimitPushDown,
         CombineLimits,
         ConstantFolding,
         BooleanSimplification) :: Nil
======================================================================
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._
 class OptimizeCodegenSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(conf)) :: Nil
+    val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen) :: Nil
   }
 
   protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {