[SPARK-46743][SQL] Count bug after constant folding
### What changes were proposed in this pull request?

This covers a corner case in the COUNT bug handling. The handling is currently split across two rules (PullupCorrelatedPredicates and RewriteCorrelatedScalarSubquery): the first marks potential COUNT bug subqueries, and the second performs the more accurate detection. Both rely on the Aggregate remaining at the top of the subquery, which is usually a safe assumption. However, when the subquery can be constant folded, the Aggregate is replaced with a Project and the second part of the COUNT bug handling falls through.

An example of when this happens: https://issues.apache.org/jira/browse/SPARK-46743 -- it involves a temp view, which gets inlined and allows us to constant fold the subquery. (Accordingly, replacing the temp view with an actual table makes the query return correct results.)

This PR makes sure that the Aggregate always remains on top of the subquery body until the RewriteCorrelatedScalarSubquery rule runs (we still run constant folding afterwards, so constant aggregates are folded away at a later point).
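
To make the failure mode concrete: COUNT over an empty correlated group must return 0 for each outer row, but a decorrelated left outer join naively yields NULL for unmatched rows; the COUNT bug handling is what turns that NULL back into 0, and it is exactly this step that was skipped once constant folding removed the Aggregate. A minimal spark-shell reproduction, adapted from the regression test added below (the setup of the outer table `l` here is an illustrative stand-in):

    // Adapted from the new query test; the `l` view is illustrative.
    spark.sql("CREATE TEMP VIEW null_view(a, b) AS SELECT CAST(null AS INT), CAST(null AS INT)")
    spark.sql("CREATE TEMP VIEW l(a) AS SELECT * FROM VALUES (1), (2)")

    // Every correlated group is empty (null_view.a is always NULL), so each
    // row must get COUNT = 0. Before this fix, inlining the view let the
    // optimizer constant fold the subquery's Aggregate into a Project, the
    // COUNT bug rewrite no longer fired, and incorrect results were returned.
    spark.sql("""
      SELECT (SELECT COUNT(null_view.a) AS result
              FROM null_view
              WHERE null_view.a = l.a)
      FROM l
    """).show()  // expected: 0 for every row of l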

### Why are the changes needed?

Correctness bug. See the reasoning above.

### Does this PR introduce _any_ user-facing change?

Yes: queries that previously returned incorrect results now return correct ones.

### How was this patch tested?

Query tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#45125 from agubichev/SPARK-46743_count.

Lead-authored-by: Andrey Gubichev <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
agubichev and cloud-fan committed Mar 25, 2024
1 parent 1b55fd3 commit b2f6474
Showing 6 changed files with 203 additions and 1 deletion.
@@ -328,6 +328,44 @@ abstract class Optimizer(catalogManager: CatalogManager)
      // Do not optimize DPP subquery, as it was created from optimized plan and we should not
      // optimize it again, to save optimization time and avoid breaking broadcast/subquery reuse.
      case d: DynamicPruningSubquery => d
      case s @ ScalarSubquery(a @ Aggregate(group, _, child), _, _, _, _, mayHaveCountBug)
          if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
            mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
        // This is a subquery with an aggregate that may suffer from a COUNT bug.
        // Detailed COUNT bug detection is done at a later stage (e.g. in
        // RewriteCorrelatedScalarSubquery).
        // Make sure that the output plan always has the same aggregate node
        // (i.e., it is not being constant folded).
        // Note that this does not limit optimization opportunities for the subquery: after
        // decorrelation is done, the subquery's body becomes part of the main plan and all
        // optimization rules are applied again.
        val projectOverAggregateChild = Project(a.references.toSeq, child)
        val optimizedPlan = Optimizer.this.execute(Subquery.fromExpression(
          s.withNewPlan(projectOverAggregateChild)))
        assert(optimizedPlan.isInstanceOf[Subquery])
        val optimizedInput = optimizedPlan.asInstanceOf[Subquery].child

        assert(optimizedInput.output.size == projectOverAggregateChild.output.size)
        // We preserve the top aggregation, but its input has been optimized via
        // Optimizer.execute().
        // Make sure that the attributes still have IDs expected by the Aggregate node
        // by inserting a project if necessary.
        val needProject = projectOverAggregateChild.output.zip(optimizedInput.output).exists {
          case (oldAttr, newAttr) => oldAttr.exprId != newAttr.exprId
        }
        if (needProject) {
          val updatedProjectList =
            projectOverAggregateChild.output.zip(optimizedInput.output).map {
              case (oldAttr, newAttr) => Alias(newAttr, newAttr.name)(exprId = oldAttr.exprId)
            }
          s.withNewPlan(a.withNewChildren(Seq(Project(updatedProjectList, optimizedInput))))
        } else {
          // Remove the top-level project if it is trivial. We do it to minimize plan changes.
          optimizedInput match {
            case Project(projectList, input) if projectList.forall(_.isInstanceOf[Attribute]) =>
              s.withNewPlan(a.withNewChildren(Seq(input)))
            case _ => s.withNewPlan(a.withNewChildren(Seq(optimizedInput)))
          }
        }
      case s: SubqueryExpression =>
        val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s))
        // At this point we have an optimized subquery plan that we are going to attach
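
For intuition, here is an illustrative sketch (not taken from the diff) of what this guards against, in the plan notation of the golden files below. After the SPARK-46743 temp view is inlined, the subquery body under the ScalarSubquery looks roughly like:

    Aggregate [count(a#1) AS result#2L]
    +- Project [cast(null as int) AS a#1]
       +- OneRowRelation

Because a#1 is provably NULL here, folding rules can reduce the whole body to a Project over a one-row relation, and RewriteCorrelatedScalarSubquery, seeing no Aggregate on top, skips the 0-for-empty-group rewrite. The new case above instead optimizes only the Aggregate's input (projectOverAggregateChild) and grafts the original Aggregate back on top, re-aligning attribute IDs with an extra Project when the optimizer has changed them.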
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
 */
object RewriteWithExpression extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
+    plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
      case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
        val inputPlans = p.children.toArray
        var newPlan: LogicalPlan = p.mapExpressions { expr =>
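
A note on this one-line change (my reading; the diff does not explain it): transformDownWithSubqueriesAndPruning also descends into plans nested inside subquery expressions, so With expressions sitting in subquery bodies — including the Aggregate subtrees now preserved by the Optimizer change above — still get rewritten. A schematic contrast, with a hypothetical rewrite function:

    // Hypothetical sketch: the two traversals differ only in whether plans
    // hanging off SubqueryExpressions (e.g. scalar subqueries) are visited.
    plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
      case p => rewrite(p)  // visits nodes of the main plan only
    }
    plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
      case p => rewrite(p)  // additionally recurses into nested subquery plans
    }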
@@ -3509,6 +3509,15 @@
      .booleanConf
      .createWithDefault(false)

  val DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG =
    buildConf("spark.sql.optimizer.decorrelateSubqueryPreventConstantHoldingForCountBug.enabled")
      .internal()
      .doc("If enabled, prevents constant folding in subqueries that contain" +
        " a COUNT-bug-susceptible Aggregate.")
      .version("4.0.0")
      .booleanConf
      .createWithDefault(true)

  val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY =
    buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery")
      .internal()
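
A usage note (not part of the diff): the flag is internal and on by default. A hypothetical session-level toggle, e.g. for bisecting a plan change — the key matches the buildConf string above:

    // Hypothetical: disabling the guard restores the pre-fix folding behavior.
    spark.conf.set(
      "spark.sql.optimizer.decorrelateSubqueryPreventConstantHoldingForCountBug.enabled",
      "false")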
@@ -168,6 +168,71 @@ Project [a#x, b#x, scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
+- LocalRelation [col1#x, col2#x]


-- !query
CREATE TEMPORARY VIEW null_view(a, b) AS SELECT CAST(null AS int), CAST(null as int)
-- !query analysis
CreateViewCommand `null_view`, [(a,None), (b,None)], SELECT CAST(null AS int), CAST(null as int), false, false, LocalTempView, true
+- Project [cast(null as int) AS CAST(NULL AS INT)#x, cast(null as int) AS CAST(NULL AS INT)#x]
   +- OneRowRelation


-- !query
SELECT
(
SELECT
COUNT(null_view.a) AS result
FROM
null_view
WHERE
null_view.a = l.a
)
FROM
l
-- !query analysis
Project [scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
:  +- Aggregate [count(a#x) AS result#xL]
:     +- Filter (a#x = outer(a#x))
:        +- SubqueryAlias null_view
:           +- View (`null_view`, [a#x, b#x])
:              +- Project [cast(CAST(NULL AS INT)#x as int) AS a#x, cast(CAST(NULL AS INT)#x as int) AS b#x]
:                 +- Project [cast(null as int) AS CAST(NULL AS INT)#x, cast(null as int) AS CAST(NULL AS INT)#x]
:                    +- OneRowRelation
+- SubqueryAlias l
   +- View (`l`, [a#x, b#x])
      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
         +- LocalRelation [col1#x, col2#x]


-- !query
SELECT
(
SELECT
COUNT(null_view.a) AS result
FROM
null_view
WHERE
null_view.a = l.a
having count(*) > -1
)
FROM
l
-- !query analysis
Project [scalar-subquery#x [a#x] AS scalarsubquery(a)#xL]
:  +- Project [result#xL]
:     +- Filter (count(1)#xL > cast(-1 as bigint))
:        +- Aggregate [count(a#x) AS result#xL, count(1) AS count(1)#xL]
:           +- Filter (a#x = outer(a#x))
:              +- SubqueryAlias null_view
:                 +- View (`null_view`, [a#x, b#x])
:                    +- Project [cast(CAST(NULL AS INT)#x as int) AS a#x, cast(CAST(NULL AS INT)#x as int) AS b#x]
:                       +- Project [cast(null as int) AS CAST(NULL AS INT)#x, cast(null as int) AS CAST(NULL AS INT)#x]
:                          +- OneRowRelation
+- SubqueryAlias l
   +- View (`l`, [a#x, b#x])
      +- Project [cast(col1#x as int) AS a#x, cast(col2#x as decimal(2,1)) AS b#x]
         +- LocalRelation [col1#x, col2#x]


-- !query
set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true
-- !query analysis
@@ -52,6 +52,37 @@ select *, (select count(*) from r where l.a = r.c having count(*) <= 1) from l;
select *, (select count(*) from r where l.a = r.c having count(*) >= 2) from l;


CREATE TEMPORARY VIEW null_view(a, b) AS SELECT CAST(null AS int), CAST(null as int);

-- SPARK-46743: count bug is still detected on top of the subquery that can be constant folded.
SELECT
(
SELECT
COUNT(null_view.a) AS result
FROM
null_view
WHERE
null_view.a = l.a
)
FROM
l;


-- Same as above but with a filter (HAVING) above the aggregate
SELECT
(
SELECT
COUNT(null_view.a) AS result
FROM
null_view
WHERE
null_view.a = l.a
having count(*) > -1
)
FROM
l;


set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true;

-- With legacy behavior flag set, both cases evaluate to 0
@@ -146,6 +146,65 @@ NULL 5.0 NULL
NULL NULL NULL


-- !query
CREATE TEMPORARY VIEW null_view(a, b) AS SELECT CAST(null AS int), CAST(null as int)
-- !query schema
struct<>
-- !query output



-- !query
SELECT
(
SELECT
COUNT(null_view.a) AS result
FROM
null_view
WHERE
null_view.a = l.a
)
FROM
l
-- !query schema
struct<scalarsubquery(a):bigint>
-- !query output
0
0
0
0
0
0
0
0


-- !query
SELECT
(
SELECT
COUNT(null_view.a) AS result
FROM
null_view
WHERE
null_view.a = l.a
having count(*) > -1
)
FROM
l
-- !query schema
struct<scalarsubquery(a):bigint>
-- !query output
0
0
0
0
0
0
0
0


-- !query
set spark.sql.optimizer.decorrelateSubqueryLegacyIncorrectCountHandling.enabled = true
-- !query schema
