[SPARK-15370][SQL] Revert PR "Update RewriteCorrelatedScalarSubquery rule"
This reverts commit 9770f6e.

Author: Herman van Hovell <[email protected]>

Closes apache#13626 from hvanhovell/SPARK-15370-revert.
hvanhovell committed Jun 12, 2016
1 parent e355460 commit 20b8f2c
Showing 3 changed files with 6 additions and 280 deletions.
@@ -69,11 +69,8 @@ trait PredicateHelper {
   protected def replaceAlias(
       condition: Expression,
       aliases: AttributeMap[Expression]): Expression = {
-    // Use transformUp to prevent infinite recursion when the replacement expression
-    // redefines the same ExprId,
-    condition.transformUp {
-      case a: Attribute =>
-        aliases.getOrElse(a, a)
+    condition.transform {
+      case a: Attribute => aliases.getOrElse(a, a)
     }
   }

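A minimal, self-contained Scala sketch (toy tree types, not Catalyst) of why the reverted change preferred transformUp in replaceAlias: Catalyst's transform is top-down, so a rule whose replacement re-introduces the matched attribute gets re-applied inside its own output and can recurse forever, while a bottom-up traversal rewrites each original node exactly once. Every name below is illustrative, not from the diff.

object TransformDirectionSketch extends App {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr

  // Apply the rule to a node first, then descend into the rewritten children (top-down).
  def transformDown(e: Expr)(rule: PartialFunction[Expr, Expr]): Expr = {
    val rewritten = rule.applyOrElse(e, identity[Expr])
    rewritten match {
      case Add(l, r) => Add(transformDown(l)(rule), transformDown(r)(rule))
      case leaf      => leaf
    }
  }

  // Rewrite the children first, then apply the rule once to the node (bottom-up).
  def transformUp(e: Expr)(rule: PartialFunction[Expr, Expr]): Expr = {
    val withNewChildren = e match {
      case Add(l, r) => Add(transformUp(l)(rule), transformUp(r)(rule))
      case leaf      => leaf
    }
    rule.applyOrElse(withNewChildren, identity[Expr])
  }

  // An alias whose definition mentions the attribute being replaced, e.g. c -> c + one.
  val aliases: PartialFunction[Expr, Expr] = { case Attr("c") => Add(Attr("c"), Attr("one")) }

  println(transformUp(Attr("c"))(aliases)) // Add(Attr(c),Attr(one)), terminates
  // transformDown(Attr("c"))(aliases)     // would re-match Attr("c") inside the replacement and never return
}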
@@ -528,8 +528,7 @@ object CollapseProject extends Rule[LogicalPlan] {
       // Substitute any attributes that are produced by the lower projection, so that we safely
       // eliminate it.
       // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
-      // Use transformUp to prevent infinite recursion.
-      val rewrittenUpper = upper.map(_.transformUp {
+      val rewrittenUpper = upper.map(_.transform {
        case a: Attribute => aliases.getOrElse(a, a)
      })
// collapse upper and lower Projects may introduce unnecessary Aliases, trim them here.
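For context, a hedged sketch of the CollapseProject behaviour the comment above describes, shown through the public DataFrame API: two stacked projections are folded into one by substituting the alias definition into the outer expression. The object name and local SparkSession are illustrative, not part of the diff.

import org.apache.spark.sql.SparkSession

object CollapseProjectSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("collapse-project").getOrCreate()

  // Two stacked projections: SELECT c + 1 AS d FROM (SELECT id + 1 AS c FROM range(3))
  val df = spark.range(3).selectExpr("id + 1 AS c").selectExpr("c + 1 AS d")

  // The optimized plan should show a single Project computing ((id + 1) + 1) AS d.
  println(df.queryExecution.optimizedPlan.numberedTreeString)

  spark.stop()
}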
@@ -1783,128 +1782,6 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
newExpression.asInstanceOf[E]
}

/**
* Statically evaluate an expression containing zero or more placeholders, given a set
* of bindings for placeholder values.
*/
private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
val rewrittenExpr = expr transform {
case r @ AttributeReference(_, dataType, _, _) =>
bindings(r.exprId) match {
case Some(v) => Literal.create(v, dataType)
case None => Literal.default(NullType)
}
}
Option(rewrittenExpr.eval())
}

/**
* Statically evaluate an expression containing one or more aggregates on an empty input.
*/
private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
// AggregateExpressions are Unevaluable, so we need to replace all aggregates
// in the expression with the value they would return for zero input tuples.
// Also replace attribute refs (for example, for grouping columns) with NULL.
val rewrittenExpr = expr transform {
case a @ AggregateExpression(aggFunc, _, _, resultId) =>
aggFunc.defaultResult.getOrElse(Literal.default(NullType))

case AttributeReference(_, _, _, _) => Literal.default(NullType)
}
Option(rewrittenExpr.eval())
}

/**
* Statically evaluate a scalar subquery on an empty input.
*
* <b>WARNING:</b> This method only covers subqueries that pass the checks under
* [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
* CheckAnalysis become less restrictive, this method will need to change.
*/
private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
// Inputs to this method will start with a chain of zero or more SubqueryAlias
// and Project operators, followed by an optional Filter, followed by an
// Aggregate. Traverse the operators recursively.
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = {
lp match {
case SubqueryAlias(_, child) => evalPlan(child)
case Filter(condition, child) =>
val bindings = evalPlan(child)
if (bindings.isEmpty) bindings
else {
val exprResult = evalExpr(condition, bindings).getOrElse(false)
.asInstanceOf[Boolean]
if (exprResult) bindings else Map.empty
}

case Project(projectList, child) =>
val bindings = evalPlan(child)
if (bindings.isEmpty) {
bindings
} else {
projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
}

case Aggregate(_, aggExprs, _) =>
// Some of the expressions under the Aggregate node are the join columns
// for joining with the outer query block. Fill those expressions in with
// nulls and statically evaluate the remainder.
aggExprs.map(ne => ne match {
case AttributeReference(_, _, _, _) => (ne.exprId, None)
case Alias(AttributeReference(_, _, _, _), _) => (ne.exprId, None)
case _ => (ne.exprId, evalAggOnZeroTups(ne))
}).toMap

case _ => sys.error(s"Unexpected operator in scalar subquery: $lp")
}
}

val resultMap = evalPlan(plan)

// By convention, the scalar subquery result is the leftmost field.
resultMap(plan.output.head.exprId)
}

/**
* Split the plan for a scalar subquery into the parts above the innermost query block
* (first part of returned value), the HAVING clause of the innermost query block
* (optional second part) and the parts below the HAVING CLAUSE (third part).
*/
private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = {
val topPart = ArrayBuffer.empty[LogicalPlan]
var bottomPart : LogicalPlan = plan
while (true) {
bottomPart match {
case havingPart@Filter(_, aggPart@Aggregate(_, _, _)) =>
return (topPart, Option(havingPart), aggPart.asInstanceOf[Aggregate])

case aggPart@Aggregate(_, _, _) =>
// No HAVING clause
return (topPart, None, aggPart)

case p@Project(_, child) =>
topPart += p
bottomPart = child

case s@SubqueryAlias(_, child) =>
topPart += s
bottomPart = child

case Filter(_, op@_) =>
sys.error(s"Correlated subquery has unexpected operator $op below filter")

case op@_ => sys.error(s"Unexpected operator $op in correlated subquery")
}
}

sys.error("This line should be unreachable")
}



// Name of generated column used in rewrite below
val ALWAYS_TRUE_COLNAME = "alwaysTrue"

/**
* Construct a new child plan by left joining the given subqueries to a base plan.
*/
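A hedged sketch of the case the removed evalAggOnZeroTups/evalSubqueryOnZeroTups helpers evaluated statically: aggregates over zero input tuples, where count(*) yields 0 (the aggregate's defaultResult in the removed code) and sum(...) falls back to NULL. The object name and local SparkSession are illustrative, not part of the diff.

import org.apache.spark.sql.SparkSession

object ZeroTupleAggSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("zero-tuple-agg").getOrCreate()

  // An empty relation: no row of range(10) satisfies id < 0.
  spark.range(10).where("id < 0").createOrReplaceTempView("empty")

  // Expected: count(*) is 0, sum(id) is NULL, which is what the removed helpers
  // computed without running the subquery.
  spark.sql("select count(*), sum(id) from empty").show()

  spark.stop()
}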
@@ -1913,76 +1790,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
       subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
     subqueries.foldLeft(child) {
       case (currentChild, ScalarSubquery(query, conditions, _)) =>
-        val origOutput = query.output.head
-
-        val resultWithZeroTups = evalSubqueryOnZeroTups(query)
-        if (resultWithZeroTups.isEmpty) {
-          // CASE 1: Subquery guaranteed not to have the COUNT bug
-          Project(
-            currentChild.output :+ origOutput,
-            Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
-        } else {
-          // Subquery might have the COUNT bug. Add appropriate corrections.
-          val (topPart, havingNode, aggNode) = splitSubquery(query)
-
-          // The next two cases add a leading column to the outer join input to make it
-          // possible to distinguish between the case when no tuples join and the case
-          // when the tuple that joins contains null values.
-          // The leading column always has the value TRUE.
-          val alwaysTrueExprId = NamedExpression.newExprId
-          val alwaysTrueExpr = Alias(Literal.TrueLiteral,
-            ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
-          val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
-            BooleanType)(exprId = alwaysTrueExprId)
-
-          val aggValRef = query.output.head
-
-          if (!havingNode.isDefined) {
-            // CASE 2: Subquery with no HAVING clause
-            Project(
-              currentChild.output :+
-                Alias(
-                  If(IsNull(alwaysTrueRef),
-                    Literal(resultWithZeroTups.get, origOutput.dataType),
-                    aggValRef), origOutput.name)(exprId = origOutput.exprId),
-              Join(currentChild,
-                Project(query.output :+ alwaysTrueExpr, query),
-                LeftOuter, conditions.reduceOption(And)))
-
-          } else {
-            // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
-            // Need to modify any operators below the join to pass through all columns
-            // referenced in the HAVING clause.
-            var subqueryRoot : UnaryNode = aggNode
-            val havingInputs : Seq[NamedExpression] = aggNode.output
-
-            topPart.reverse.foreach(
-              _ match {
-                case Project(projList, _) =>
-                  subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
-                case s@SubqueryAlias(alias, _) => subqueryRoot = SubqueryAlias(alias, subqueryRoot)
-                case op@_ => sys.error(s"Unexpected operator $op in corelated subquery")
-              }
-            )
-
-            // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
-            //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
-            //      ELSE (aggregate value) END AS (original column name)
-            val caseExpr = Alias(CaseWhen(
-              Seq[(Expression, Expression)] (
-                (IsNull(alwaysTrueRef), Literal(resultWithZeroTups.get, origOutput.dataType)),
-                (Not(havingNode.get.condition), Literal(null, aggValRef.dataType))
-              ), aggValRef
-            ), origOutput.name) (exprId = origOutput.exprId)
-
-            Project(
-              currentChild.output :+ caseExpr,
-              Join(currentChild,
-                Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
-                LeftOuter, conditions.reduceOption(And)))
-
-          }
-        }
+        Project(
+          currentChild.output :+ query.output.head,
+          Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
     }
   }
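A hedged, self-contained sketch of the COUNT bug that the code removed above guarded against: with the plain left outer join rewrite restored by this revert, a correlated count(*) that matches zero rows can surface as NULL rather than 0. The tiny l/r tables and the object name are illustrative only; the removed SubquerySuite tests below exercise the same queries on richer data.

import org.apache.spark.sql.SparkSession

object CountBugSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("count-bug").getOrCreate()
  import spark.implicits._

  Seq((1, 2.0), (3, 3.0)).toDF("a", "b").createOrReplaceTempView("l")
  Seq((3, 4.0)).toDF("c", "d").createOrReplaceTempView("r")

  // For l.a = 1 there is no matching row in r, so the scalar subquery aggregates over
  // zero tuples and a correct answer is cnt = 0; after this revert the row can come
  // back with cnt = NULL, which is what the removed CASE 1/2/3 handling prevented.
  spark.sql("select l.a, (select count(*) from r where l.a = r.c) as cnt from l").show()

  spark.stop()
}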

81 changes: 0 additions & 81 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -490,85 +490,4 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
""".stripMargin),
Row(3) :: Nil)
}

test("SPARK-15370: COUNT bug in WHERE clause (Filter)") {
// Case 1: Canonical example of the COUNT bug
checkAnswer(
sql("select l.a from l where (select count(*) from r where l.a = r.c) < l.a"),
Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
// Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
// a rewrite that is vulnerable to the COUNT bug
checkAnswer(
sql("select l.a from l where (select count(*) from r where l.a = r.c) = 0"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
// Case 3: COUNT bug without a COUNT aggregate
checkAnswer(
sql("select l.a from l where (select sum(r.d) is null from r where l.a = r.c)"),
Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
}

test("SPARK-15370: COUNT bug in SELECT clause (Project)") {
checkAnswer(
sql("select a, (select count(*) from r where l.a = r.c) as cnt from l"),
Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0)
:: Row(null, 0) :: Row(6, 1) :: Nil)
}

test("SPARK-15370: COUNT bug in HAVING clause (Filter)") {
checkAnswer(
sql("select l.a as grp_a from l group by l.a " +
"having (select count(*) from r where grp_a = r.c) = 0 " +
"order by grp_a"),
Row(null) :: Row(1) :: Nil)
}

test("SPARK-15370: COUNT bug in Aggregate") {
checkAnswer(
sql("select l.a as aval, sum((select count(*) from r where l.a = r.c)) as cnt " +
"from l group by l.a order by aval"),
Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil)
}

test("SPARK-15370: COUNT bug negative examples") {
// Case 1: Potential COUNT bug case that was working correctly prior to the fix
checkAnswer(
sql("select l.a from l where (select sum(r.d) from r where l.a = r.c) is null"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
// Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
checkAnswer(
sql("select l.a from l where (select count(*) from r where l.a = r.c) > 0"),
Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
// Case 3: COUNT inside aggregate expression but no COUNT bug.
checkAnswer(
sql("select l.a from l where (select count(*) + sum(r.d) from r where l.a = r.c) = 0"),
Nil)
}

test("SPARK-15370: COUNT bug in subquery in subquery in subquery") {
checkAnswer(
sql("""select l.a from l
|where (
| select cntPlusOne + 1 as cntPlusTwo from (
| select cnt + 1 as cntPlusOne from (
| select sum(r.c) s, count(*) cnt from r where l.a = r.c having cnt = 0
| )
| )
|) = 2""".stripMargin),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
}

test("SPARK-15370: COUNT bug with nasty predicate expr") {
checkAnswer(
sql("select l.a from l where " +
"(select case when count(*) = 1 then null else count(*) end as cnt " +
"from r where l.a = r.c) = 0"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
}

test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") {
checkAnswer(
sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"),
Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
}
}
