From ee6ea3c68694e35c36ad006a7762297800d1e463 Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Thu, 28 Apr 2022 00:43:55 +0800
Subject: [PATCH] [SPARK-38997][SQL] DS V2 aggregate push-down supports group by expressions

### What changes were proposed in this pull request?
Currently, Spark DS V2 aggregate push-down only supports grouping by columns. But SQL that groups by an expression, such as the statement shown below, is very useful and common.
```
SELECT CASE WHEN "SALARY" > 8000.00 AND "SALARY" < 10000.00 THEN "SALARY" ELSE 0.00 END AS key,
       SUM("SALARY")
FROM "test"."employee"
GROUP BY key
```

### Why are the changes needed?
Make DS V2 aggregate push-down support group by expressions.

### Does this PR introduce _any_ user-facing change?
'No'.
New feature.

### How was this patch tested?
New tests.

Closes #36325 from beliefer/SPARK-38997.

Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan
---
 .../expressions/aggregate/Aggregation.java    |  10 +-
 .../sql/execution/DataSourceScanExec.scala    |   2 +-
 .../datasources/AggregatePushDownUtils.scala  |  23 ++--
 .../datasources/DataSourceStrategy.scala      |   7 +-
 .../execution/datasources/orc/OrcUtils.scala  |   2 +-
 .../datasources/parquet/ParquetUtils.scala    |   2 +-
 .../v2/V2ScanRelationPushDown.scala           |  23 +++-
 .../datasources/v2/jdbc/JDBCScanBuilder.scala |  27 ++--
 .../datasources/v2/orc/OrcScan.scala          |   2 +-
 .../datasources/v2/parquet/ParquetScan.scala  |   2 +-
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala   | 120 ++++++++++++++----
 11 files changed, 151 insertions(+), 69 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java
index cf7dbb2978dd7..11d9e475ca1bf 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java
@@ -20,7 +20,7 @@
 import java.io.Serializable;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * Aggregation in SQL statement.
@@ -30,14 +30,14 @@ @Evolving public final class Aggregation implements Serializable { private final AggregateFunc[] aggregateExpressions; - private final NamedReference[] groupByColumns; + private final Expression[] groupByExpressions; - public Aggregation(AggregateFunc[] aggregateExpressions, NamedReference[] groupByColumns) { + public Aggregation(AggregateFunc[] aggregateExpressions, Expression[] groupByExpressions) { this.aggregateExpressions = aggregateExpressions; - this.groupByColumns = groupByColumns; + this.groupByExpressions = groupByExpressions; } public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; } - public NamedReference[] groupByColumns() { return groupByColumns; } + public Expression[] groupByExpressions() { return groupByExpressions; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 953a7db0f9da1..9141a3f742e83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -161,7 +161,7 @@ case class RowDataSourceScanExec( "PushedFilters" -> pushedFilters) ++ pushedDownOperators.aggregation.fold(Map[String, String]()) { v => Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())), - "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++ + "PushedGroupByExpressions" -> seqToString(v.groupByExpressions.map(_.describe())))} ++ topNOrLimitInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index 4779a3eaf2531..97ee3cd661b3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils @@ -93,8 +94,8 @@ object AggregatePushDownUtils { return None } - if (aggregation.groupByColumns.nonEmpty && - partitionNames.size != aggregation.groupByColumns.length) { + if (aggregation.groupByExpressions.nonEmpty && + partitionNames.size != aggregation.groupByExpressions.length) { // If there are group by columns, we only push down if the group by columns are the same as // the partition columns. In theory, if group by columns are a subset of partition columns, // we should still be able to push down. e.g. if table t has partition columns p1, p2, and p3, @@ -106,11 +107,11 @@ object AggregatePushDownUtils { // aggregate push down simple and don't handle this complicate case for now. 
return None } - aggregation.groupByColumns.foreach { col => + aggregation.groupByExpressions.map(extractColName).foreach { colName => // don't push down if the group by columns are not the same as the partition columns (orders // doesn't matter because reorder can be done at data source layer) - if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None - finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head)) + if (colName.isEmpty || !isPartitionCol(colName.get)) return None + finalSchema = finalSchema.add(getStructFieldForCol(colName.get)) } aggregation.aggregateExpressions.foreach { @@ -137,7 +138,8 @@ object AggregatePushDownUtils { def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { a.aggregateExpressions.sortBy(_.hashCode()) .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && - a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) + a.groupByExpressions.sortBy(_.hashCode()) + .sameElements(b.groupByExpressions.sortBy(_.hashCode())) } /** @@ -164,7 +166,7 @@ object AggregatePushDownUtils { def getSchemaWithoutGroupingExpression( aggSchema: StructType, aggregation: Aggregation): StructType = { - val numOfGroupByColumns = aggregation.groupByColumns.length + val numOfGroupByColumns = aggregation.groupByExpressions.length if (numOfGroupByColumns > 0) { new StructType(aggSchema.fields.drop(numOfGroupByColumns)) } else { @@ -179,7 +181,7 @@ object AggregatePushDownUtils { partitionSchema: StructType, aggregation: Aggregation, partitionValues: InternalRow): InternalRow = { - val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head) + val groupByColNames = aggregation.groupByExpressions.flatMap(extractColName) assert(groupByColNames.length == partitionSchema.length && groupByColNames.length == partitionValues.numFields, "The number of group by columns " + s"${groupByColNames.length} should be the same as partition schema length " + @@ -197,4 +199,9 @@ object AggregatePushDownUtils { partitionValues } } + + private def extractColName(v2Expr: V2Expression): Option[String] = v2Expr match { + case f: FieldReference if f.fieldNames.length == 1 => Some(f.fieldNames.head) + case _ => None + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1b14884e75994..e35d09320760c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -759,14 +759,13 @@ object DataSourceStrategy protected[sql] def translateAggregation( aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = { - def columnAsString(e: Expression): Option[FieldReference] = e match { - case PushableColumnWithoutNestedColumn(name) => - Some(FieldReference.column(name).asInstanceOf[FieldReference]) + def translateGroupBy(e: Expression): Option[V2Expression] = e match { + case PushableExpression(expr) => Some(expr) case _ => None } val translatedAggregates = aggregates.flatMap(translateAggregate) - val translatedGroupBys = groupBy.flatMap(columnAsString) + val translatedGroupBys = groupBy.flatMap(translateGroupBy) if (translatedAggregates.length != aggregates.length || translatedGroupBys.length != groupBy.length) { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 1783aadaa7896..582a3a0156f7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -520,7 +520,7 @@ object OrcUtils extends Logging { val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupBy, (0 until schemaWithoutGroupBy.length).toArray) val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues) - if (aggregation.groupByColumns.nonEmpty) { + if (aggregation.groupByExpressions.nonEmpty) { val reOrderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol( partitionSchema, aggregation, partitionValues) new JoinedRow(reOrderedPartitionValues, resultRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 5a291e6a2e509..7c0348d58333c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -279,7 +279,7 @@ object ParquetUtils { throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i)) } - if (aggregation.groupByColumns.nonEmpty) { + if (aggregation.groupByExpressions.nonEmpty) { val reorderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol( partitionSchema, aggregation, partitionValues) new JoinedRow(reorderedPartitionValues, converter.currentRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index b7e0531989f42..89398fabdc314 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -183,9 +183,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit // scalastyle:on val newOutput = scan.readSchema().toAttributes assert(newOutput.length == groupingExpressions.length + finalAggregates.length) - val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { - case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) - case (_, b) => b + val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map { + case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId) + case ((expr, attr), ordinal) => + if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { + groupByExprToOutputOrdinal(expr.canonicalized) = ordinal + } + attr } val aggOutput = newOutput.drop(groupAttrs.length) val output = groupAttrs ++ aggOutput @@ -196,7 +201,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit |Pushed Aggregate Functions: | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} |Pushed Group by: - | ${pushedAggregates.get.groupByColumns.mkString(", ")} + | ${pushedAggregates.get.groupByExpressions.mkString(", ")} |Output: ${output.mkString(", ")} """.stripMargin) @@ -205,14 +210,15 @@ object V2ScanRelationPushDown extends 
Rule[LogicalPlan] with PredicateHelper wit DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) if (r.supportCompletePushDown(pushedAggregates.get)) { val projectExpressions = finalResultExpressions.map { expr => - // TODO At present, only push down group by attribute is supported. - // In future, more attribute conversion is extended here. e.g. GetStructField - expr.transform { + expr.transformDown { case agg: AggregateExpression => val ordinal = aggExprToOutputOrdinal(agg.canonicalized) val child = addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + addCastIfNeeded(groupAttrs(ordinal), expr.dataType) } }.asInstanceOf[Seq[NamedExpression]] Project(projectExpressions, scanRelation) @@ -255,6 +261,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit case other => other } agg.copy(aggregateFunction = aggFunction) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + addCastIfNeeded(groupAttrs(ordinal), expr.dataType) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 0a1542a42956d..8b378d2d87c49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -20,7 +20,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} @@ -70,12 +70,15 @@ case class JDBCScanBuilder( private var pushedAggregateList: Array[String] = Array() - private var pushedGroupByCols: Option[Array[String]] = None + private var pushedGroupBys: Option[Array[String]] = None override def supportCompletePushDown(aggregation: Aggregation): Boolean = { - lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames() + lazy val fieldNames = aggregation.groupByExpressions()(0) match { + case field: FieldReference => field.fieldNames + case _ => Array.empty[String] + } jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) || - (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 && + (aggregation.groupByExpressions().length == 1 && fieldNames.length == 1 && jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_))) } @@ -86,20 +89,18 @@ case class JDBCScanBuilder( val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate) if (compiledAggs.length != aggregation.aggregateExpressions.length) return false - val groupByCols = aggregation.groupByColumns.map { col => - if (col.fieldNames.length != 1) return false - dialect.quoteIdentifier(col.fieldNames.head) - 
} + val compiledGroupBys = aggregation.groupByExpressions.flatMap(dialect.compileExpression) + if (compiledGroupBys.length != aggregation.groupByExpressions.length) return false // The column names here are already quoted and can be used to build sql string directly. // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") => // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee" // GROUP BY "DEPT", "NAME" - val selectList = groupByCols ++ compiledAggs - val groupByClause = if (groupByCols.isEmpty) { + val selectList = compiledGroupBys ++ compiledAggs + val groupByClause = if (compiledGroupBys.isEmpty) { "" } else { - "GROUP BY " + groupByCols.mkString(",") + "GROUP BY " + compiledGroupBys.mkString(",") } val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " + @@ -107,7 +108,7 @@ case class JDBCScanBuilder( try { finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect) pushedAggregateList = selectList - pushedGroupByCols = Some(groupByCols) + pushedGroupBys = Some(compiledGroupBys) true } catch { case NonFatal(e) => @@ -173,6 +174,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate, - pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders) + pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, sortOrders) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index ad8857d98037c..ccb9ca9c6b3f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -87,7 +87,7 @@ case class OrcScan( lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { (seqToString(pushedAggregate.get.aggregateExpressions), - seqToString(pushedAggregate.get.groupByColumns)) + seqToString(pushedAggregate.get.groupByExpressions)) } else { ("[]", "[]") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index ae55159c07d3a..0457e8be71540 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -119,7 +119,7 @@ case class ParquetScan( lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { (seqToString(pushedAggregate.get.aggregateExpressions), - seqToString(pushedAggregate.get.groupByColumns)) + seqToString(pushedAggregate.get.groupByExpressions)) } else { ("[]", "[]") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 5cfa2f465a2be..74e226acb7a14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -187,7 +187,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkLimitRemoved(df4, false) checkPushedInfo(df4, - "PushedAggregates: [SUM(SALARY)], 
PushedFilters: [], PushedGroupByColumns: [DEPT], ") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") checkAnswer(df4, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -279,7 +279,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkSortRemoved(df6, false) checkLimitRemoved(df6, false) checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," + - " PushedFilters: [], PushedGroupByColumns: [DEPT], ") + " PushedFilters: [], PushedGroupByExpressions: [DEPT], ") checkAnswer(df6, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -633,7 +633,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT], ") + "PushedGroupByExpressions: [DEPT], ") checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) } @@ -654,7 +654,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MAX(ID), AVG(ID)], " + "PushedFilters: [ID IS NOT NULL, ID > 0], " + - "PushedGroupByColumns: [], ") + "PushedGroupByExpressions: [], ") checkAnswer(df, Seq(Row(2, 1.5))) } @@ -736,18 +736,84 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } test("scan with aggregate push-down: SUM with group by") { - val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") - checkAggregateRemoved(df) - checkPushedInfo(df, "PushedAggregates: [SUM(SALARY)], " + - "PushedFilters: [], PushedGroupByColumns: [DEPT], ") - checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) + val df1 = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") + checkAggregateRemoved(df1) + checkPushedInfo(df1, "PushedAggregates: [SUM(SALARY)], " + + "PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + checkAnswer(df1, Seq(Row(19000), Row(22000), Row(12000))) + + val df2 = sql( + """ + |SELECT CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END as key, + | SUM(SALARY) FROM h2.test.employee GROUP BY key""".stripMargin) + checkAggregateRemoved(df2) + checkPushedInfo(df2, + """ + |PushedAggregates: [SUM(SALARY)], + |PushedFilters: [], + |PushedGroupByExpressions: + |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df2, Seq(Row(0, 44000), Row(9000, 9000))) + + val df3 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy(when(($"SALARY" > 8000).and($"SALARY" < 10000), $"SALARY").otherwise(0).as("key")) + .agg(sum($"SALARY")) + checkAggregateRemoved(df3, false) + checkPushedInfo(df3, + """ + |PushedAggregates: [SUM(SALARY)], + |PushedFilters: [], + |PushedGroupByExpressions: + |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df3, Seq(Row(0, 44000), Row(9000, 9000))) + + val df4 = sql( + """ + |SELECT DEPT, CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END as key, + | SUM(SALARY) FROM h2.test.employee GROUP BY DEPT, key""".stripMargin) + checkAggregateRemoved(df4) + checkPushedInfo(df4, + """ + |PushedAggregates: 
[SUM(SALARY)], + |PushedFilters: [], + |PushedGroupByExpressions: + |[DEPT, CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df4, Seq(Row(1, 0, 10000), Row(1, 9000, 9000), Row(2, 0, 22000), Row(6, 0, 12000))) + + val df5 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"DEPT", + when(($"SALARY" > 8000).and($"SALARY" < 10000), $"SALARY").otherwise(0) + .as("key")) + .agg(sum($"SALARY")) + checkAggregateRemoved(df5, false) + checkPushedInfo(df5, + """ + |PushedAggregates: [SUM(SALARY)], + |PushedFilters: [], + |PushedGroupByExpressions: + |[DEPT, CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df5, Seq(Row(1, 0, 10000), Row(1, 9000, 9000), Row(2, 0, 22000), Row(6, 0, 12000))) } test("scan with aggregate push-down: DISTINCT SUM with group by") { val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT") checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)], " + - "PushedFilters: [], PushedGroupByColumns: [DEPT]") + "PushedFilters: [], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } @@ -757,7 +823,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT, NAME]") checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), Row(10000, 1000), Row(12000, 1200))) } @@ -771,7 +837,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel assert(filters1.isEmpty) checkAggregateRemoved(df1) checkPushedInfo(df1, "PushedAggregates: [MAX(SALARY)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT, NAME]") checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000), Row("2#david", 10000), Row("6#jen", 12000))) @@ -783,7 +849,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel assert(filters2.isEmpty) checkAggregateRemoved(df2) checkPushedInfo(df2, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT, NAME]") checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), Row("2#david", 11300), Row("6#jen", 13200))) @@ -803,7 +869,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df, false) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200))) } @@ -813,7 +879,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .min("SALARY").as("total") 
checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [MIN(SALARY)], " + - "PushedFilters: [], PushedGroupByColumns: [DEPT]") + "PushedFilters: [], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000))) } @@ -828,7 +894,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(query, false)// filter over aggregate not pushed down checkAggregateRemoved(query) checkPushedInfo(query, "PushedAggregates: [SUM(SALARY)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000))) } @@ -860,7 +926,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -870,7 +936,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) } @@ -880,7 +946,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -890,7 +956,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkFiltersRemoved(df) checkAggregateRemoved(df) checkPushedInfo(df, "PushedAggregates: [CORR(BONUS, BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) } @@ -902,7 +968,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel df2.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => val expectedPlanFragment = - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []" + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: []" checkKeywordsExistsInExplain(df2, expectedPlanFragment) relation.scan match { case v1: V1ScanWrapper => @@ -955,7 +1021,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "PushedAggregates: [COUNT(CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00)" + " THEN SALARY ELSE 0.00 END), COUNT(CAS..., " + "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT], ") + "PushedGroupByExpressions: [DEPT], ") checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 0d, 0d, 2, 0d), Row(2, 2, 2, 2, 2, 
10000d, 12000d, 10000d, 12000d, 0d, 0d, 3, 0d), Row(2, 2, 2, 2, 2, 10000d, 9000d, 10000d, 10000d, 9000d, 0d, 2, 0d))) @@ -969,7 +1035,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expectedPlanFragment = if (ansiMode) { "PushedAggregates: [SUM(2147483647 + DEPT)], " + "PushedFilters: [], " + - "PushedGroupByColumns: []" + "PushedGroupByExpressions: []" } else { "PushedFilters: []" } @@ -1118,7 +1184,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"total" > 1000) checkAggregateRemoved(df) checkPushedInfo(df, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT]") checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) val df2 = spark.table("h2.test.employee") @@ -1128,7 +1194,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"total" > 1000) checkAggregateRemoved(df2) checkPushedInfo(df2, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT]") checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) } @@ -1145,7 +1211,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"total" > 1000) checkAggregateRemoved(df, false) checkPushedInfo(df, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [NAME]") checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) @@ -1161,7 +1227,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"total" > 1000) checkAggregateRemoved(df2, false) checkPushedInfo(df2, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]") + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [NAME]") checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) }