Skip to content

Commit

Permalink
[SPARK-48503][SQL] Allow grouping on expressions in scalar subqueries…
Browse files Browse the repository at this point in the history
…, if they are bound to outer rows

### What changes were proposed in this pull request?

Extends previous work in apache#46839, allowing the grouping expressions to be bound to outer references.

The most common example is:
`select *, (select count(*) from T_inner where cast(T_inner.x as date) = T_outer.date group by cast(T_inner.x as date))`

Here, we group by `cast(T_inner.x as date)`, which is bound to an outer row by the equality predicate. This guarantees that, for every outer row, there is exactly one value of `cast(T_inner.x as date)`, so it is safe to group on it.
Previously, only plain columns were recognized as bound to outer expressions, so such subqueries were rejected.

### Why are the changes needed?

Extends supported subqueries

### Does this PR introduce _any_ user-facing change?

Yes. Some scalar subqueries that previously failed analysis are now accepted.

### How was this patch tested?

Query tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#47388 from agubichev/group_by_cols.

Authored-by: Andrey Gubichev <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
agubichev authored and cloud-fan committed Jul 26, 2024
1 parent 78b83fa commit f3b819e
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -905,26 +905,31 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
// are not part of the correlated columns.

// Note: groupByCols does not contain outer refs - grouping by an outer ref is always ok
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
// Collect the inner query attributes that are guaranteed to have a single value for each
// outer row. See comment on getCorrelatedEquivalentInnerColumns.
val correlatedEquivalentCols = getCorrelatedEquivalentInnerColumns(query)
val nonEquivalentGroupByCols = groupByCols -- correlatedEquivalentCols
// Collect the inner query expressions that are guaranteed to have a single value for each
// outer row. See comment on getCorrelatedEquivalentInnerExpressions.
val correlatedEquivalentExprs = getCorrelatedEquivalentInnerExpressions(query)
// Grouping expressions, except outer refs and constant expressions - grouping by an
// outer ref or a constant is always ok
val groupByExprs =
ExpressionSet(agg.groupingExpressions.filter(x => !x.isInstanceOf[OuterReference] &&
x.references.nonEmpty))
val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs

val invalidCols = if (!SQLConf.get.getConf(
SQLConf.LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE)) {
nonEquivalentGroupByCols
nonEquivalentGroupByExprs
} else {
// Legacy incorrect logic for checking for invalid group-by columns (see SPARK-48503).
// Allows any inner attribute that appears in a correlated predicate, even if it is a
// non-equality predicate or under an operator that can change the values of the attribute
// (see comments on getCorrelatedEquivalentInnerExpressions for examples).
// Note: groupByCols does not contain outer refs - grouping by an outer ref is always ok
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
.filterNot(conditions.flatMap(_.references).contains)
val correlatedCols = AttributeSet(subqueryColumns)
val invalidColsLegacy = groupByCols -- correlatedCols
if (!nonEquivalentGroupByCols.isEmpty && invalidColsLegacy.isEmpty) {
if (!nonEquivalentGroupByExprs.isEmpty && invalidColsLegacy.isEmpty) {
logWarning(log"Using legacy behavior for " +
log"${MDC(LogKeys.CONFIG, SQLConf
.LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE.key)}. " +
Expand All @@ -936,10 +941,16 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
}

if (invalidCols.nonEmpty) {
val names = invalidCols.map { el =>
el match {
case attr: Attribute => attr.name
case expr: Expression => expr.toString
}
}
expr.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"NON_CORRELATED_COLUMNS_IN_GROUP_BY",
messageParameters = Map("value" -> invalidCols.map(_.name).mkString(",")))
messageParameters = Map("value" -> names.mkString(",")))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.DecorrelateInnerQuery
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreePattern._
Expand Down Expand Up @@ -252,15 +251,50 @@ object SubExprUtils extends PredicateHelper {
}

/**
* Returns the inner query attributes that are guaranteed to have a single value for each
* outer row. Therefore, a scalar subquery is allowed to group-by on these attributes.
* Matches an equality 'expr = func(outer)', where 'func(outer)' depends on outer rows or
* is a constant.
* A scalar subquery is allowed to group-by on 'expr', as it is guaranteed to have exactly
* one value for every outer row.
* Positive examples:
* - x + 1 = outer(a)
* - cast(x as date) = outer(b)
* - y + z = 100 (only when the "column equal to constant" SQLConf flag is enabled)
* - y / 10 = outer(b) + outer(c)
* In all of these examples, the left side of the equality will be returned.
*
* Negative examples:
* - x < outer(b)
* - x = y
* In all of these examples, None will be returned.
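* @param expr a single predicate, e.g. one conjunct of a correlated filter condition
* @return the side of the equality that is fixed by the other side, or None if `expr`
*         does not have this form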
*/
private def getEquivalentToOuter(expr: Expression): Option[Expression] = {
  // Read the flag once: when it is enabled, a side that contains no Attribute node
  // qualifies even if it holds no outer reference (i.e. it is a constant expression).
  val constantsAllowed =
    SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_ALLOW_GROUP_BY_COLUMN_EQUAL_TO_CONSTANT)

  // A side "fixes" the opposite side of an equality when it contains no Attribute node and
  // either contains an outer reference or constants are allowed by the flag above.
  def fixesOtherSide(side: Expression): Boolean =
    (constantsAllowed || containsOuter(side)) && !side.exists(_.isInstanceOf[Attribute])

  expr match {
    // The right-hand side is tested first, so the left operand is returned when both
    // operands qualify.
    case EqualTo(lhs, rhs) if fixesOtherSide(rhs) => Some(lhs)
    case EqualTo(lhs, rhs) if fixesOtherSide(lhs) => Some(rhs)
    // Not an equality, or neither operand is free of inner attributes.
    case _ => None
  }
}

/**
* Returns the inner query expressions that are guaranteed to have a single value for each
* outer row. Therefore, a scalar subquery is allowed to group-by on these expressions.
* We can derive these from correlated equality predicates, but we must be careful when
* propagating them through operators like OUTER JOIN or UNION.
*
* Positive examples:
* - x = outer(a) AND y = outer(b)
* - x = 1
* - x = outer(a) + 1
* - cast(x as date) = current_date() + outer(b)
*
* Negative examples:
* - x <= outer(a)
Expand All @@ -274,31 +308,31 @@ object SubExprUtils extends PredicateHelper {
* select *, (select count(*) from
* (select * from y where y1 = x1 union all select * from y) group by y1) from x;
*/
def getCorrelatedEquivalentInnerColumns(plan: LogicalPlan): AttributeSet = {
def getCorrelatedEquivalentInnerExpressions(plan: LogicalPlan): ExpressionSet = {
plan match {
case Filter(cond, child) =>
val correlated = AttributeSet(splitConjunctivePredicates(cond)
val equivalentExprs = ExpressionSet(splitConjunctivePredicates(cond)
.filter(
SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_ALLOW_GROUP_BY_COLUMN_EQUAL_TO_CONSTANT)
|| containsOuter(_))
.filter(DecorrelateInnerQuery.canPullUpOverAgg)
.flatMap(_.references))
correlated ++ getCorrelatedEquivalentInnerColumns(child)
.flatMap(getEquivalentToOuter))
equivalentExprs ++ getCorrelatedEquivalentInnerExpressions(child)

case Join(left, right, joinType, _, _) =>
joinType match {
case _: InnerLike =>
AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child)))
case LeftOuter => getCorrelatedEquivalentInnerColumns(left)
case RightOuter => getCorrelatedEquivalentInnerColumns(right)
case FullOuter => AttributeSet.empty
case LeftSemi => getCorrelatedEquivalentInnerColumns(left)
case LeftAnti => getCorrelatedEquivalentInnerColumns(left)
case _ => AttributeSet.empty
ExpressionSet(plan.children.flatMap(
child => getCorrelatedEquivalentInnerExpressions(child)))
case LeftOuter => getCorrelatedEquivalentInnerExpressions(left)
case RightOuter => getCorrelatedEquivalentInnerExpressions(right)
case FullOuter => ExpressionSet().empty
case LeftSemi => getCorrelatedEquivalentInnerExpressions(left)
case LeftAnti => getCorrelatedEquivalentInnerExpressions(left)
case _ => ExpressionSet().empty
}

case _: Union => AttributeSet.empty
case Except(left, right, _) => getCorrelatedEquivalentInnerColumns(left)
case _: Union => ExpressionSet().empty
case Except(left, _, _) => getCorrelatedEquivalentInnerExpressions(left)

case
_: Aggregate |
Expand All @@ -318,9 +352,10 @@ object SubExprUtils extends PredicateHelper {
_: WithCTE |
_: Range |
_: SubqueryAlias =>
AttributeSet(plan.children.flatMap(child => getCorrelatedEquivalentInnerColumns(child)))
ExpressionSet(plan.children.flatMap(child =>
getCorrelatedEquivalentInnerExpressions(child)))

case _ => AttributeSet.empty
case _ => ExpressionSet().empty
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,39 @@ Project [x1#x, x2#x, scalar-subquery#x [x1#x && x1#x] AS scalarsubquery(x1, x1)#
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (select count(*) from y where x1 = y1 and cast(y2 as double) = x1 + 1
group by cast(y2 as double)) from x
-- !query analysis
Project [x1#x, x2#x, scalar-subquery#x [x1#x && x1#x] AS scalarsubquery(x1, x1)#xL]
: +- Aggregate [cast(y2#x as double)], [count(1) AS count(1)#xL]
: +- Filter ((outer(x1#x) = y1#x) AND (cast(y2#x as double) = cast((outer(x1#x) + 1) as double)))
: +- SubqueryAlias y
: +- View (`y`, [y1#x, y2#x])
: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias x
+- View (`x`, [x1#x, x2#x])
+- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select *, (select count(*) from y where y2 + 1 = x1 + x2 group by y2 + 1) from x
-- !query analysis
Project [x1#x, x2#x, scalar-subquery#x [x1#x && x2#x] AS scalarsubquery(x1, x2)#xL]
: +- Aggregate [(y2#x + 1)], [count(1) AS count(1)#xL]
: +- Filter ((y2#x + 1) = (outer(x1#x) + outer(x2#x)))
: +- SubqueryAlias y
: +- View (`y`, [y1#x, y2#x])
: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias x
+- View (`x`, [x1#x, x2#x])
+- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1
-- !query analysis
Expand Down Expand Up @@ -149,6 +182,26 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
}


-- !query
select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY",
"sqlState" : "0A000",
"messageParameters" : {
"value" : "y2"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 11,
"stopIndex" : 81,
"fragment" : "(select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2)"
} ]
}


-- !query
select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@ select * from x where (select count(*) from y where y1 > x1 group by x1) = 1;
select *, (select count(*) from y where x1 = y1 and y2 = 1 group by y2) from x;
-- Group-by column equal to expression with constants and outer refs - legal
select *, (select count(*) from y where x1 = y1 and y2 = x1 + 1 group by y2) from x;
-- Group-by expression is the same as the one we filter on - legal
select *, (select count(*) from y where x1 = y1 and cast(y2 as double) = x1 + 1
group by cast(y2 as double)) from x;
-- Group-by expression equal to an expression that depends on 2 outer refs -- legal
select *, (select count(*) from y where y2 + 1 = x1 + x2 group by y2 + 1) from x;


-- Illegal queries
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1;
select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x;
select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x;

-- Certain other operators like OUTER JOIN or UNION between the correlating filter and the group-by also can cause the scalar subquery to return multiple values and hence make the query illegal.
select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,25 @@ struct<x1:int,x2:int,scalarsubquery(x1, x1):bigint>
2 2 NULL


-- !query
select *, (select count(*) from y where x1 = y1 and cast(y2 as double) = x1 + 1
group by cast(y2 as double)) from x
-- !query schema
struct<x1:int,x2:int,scalarsubquery(x1, x1):bigint>
-- !query output
1 1 NULL
2 2 NULL


-- !query
select *, (select count(*) from y where y2 + 1 = x1 + x2 group by y2 + 1) from x
-- !query schema
struct<x1:int,x2:int,scalarsubquery(x1, x2):bigint>
-- !query output
1 1 NULL
2 2 NULL


-- !query
select * from x where (select count(*) from y where y1 > x1 group by y1) = 1
-- !query schema
Expand Down Expand Up @@ -137,6 +156,28 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
}


-- !query
select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY",
"sqlState" : "0A000",
"messageParameters" : {
"value" : "y2"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 11,
"stopIndex" : 81,
"fragment" : "(select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2)"
} ]
}


-- !query
select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x
-- !query schema
Expand Down

0 comments on commit f3b819e

Please sign in to comment.