[SPARK-38997][SQL] DS V2 aggregate push-down supports group by expressions

### What changes were proposed in this pull request?
Currently, Spark DS V2 aggregate push-down only supports grouping by columns, but SQL like the query shown below is both useful and common.
```
SELECT
  CASE
    WHEN "SALARY" > 8000.00
      AND "SALARY" < 10000.00
    THEN "SALARY"
    ELSE 0.00
  END AS key,
  SUM("SALARY")
FROM "test"."employee"
GROUP BY key
```
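
On the DS V2 side, a group-by key like the `CASE ... END` above can now be carried as a general expression tree rather than a bare column. A minimal sketch, assuming Spark 3.3's connector expression classes (`GeneralScalarExpression`, `LiteralValue`, `FieldReference`); the exact tree Catalyst's translator emits may differ:

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.types.DoubleType

val salary: Expression = FieldReference.column("SALARY")
// "SALARY" > 8000.00 AND "SALARY" < 10000.00
val inRange = new GeneralScalarExpression("AND", Array[Expression](
  new GeneralScalarExpression(">", Array[Expression](salary, LiteralValue(8000.00, DoubleType))),
  new GeneralScalarExpression("<", Array[Expression](salary, LiteralValue(10000.00, DoubleType)))))
// CASE WHEN <inRange> THEN "SALARY" ELSE 0.00 END, usable as a single group-by key
val key: Expression = new GeneralScalarExpression("CASE_WHEN",
  Array[Expression](inRange, salary, LiteralValue(0.00, DoubleType)))
```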

### Why are the changes needed?
Make DS V2 aggregate push-down support group by expressions.

### Does this PR introduce _any_ user-facing change?
'No'.
New feature.

### How was this patch tested?
New tests

Closes apache#36325 from beliefer/SPARK-38997.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
beliefer authored and cloud-fan committed Apr 27, 2022
1 parent 852997d commit ee6ea3c
Showing 11 changed files with 151 additions and 69 deletions.
Aggregation.java
```diff
@@ -20,7 +20,7 @@
 import java.io.Serializable;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * Aggregation in SQL statement.
@@ -30,14 +30,14 @@
 @Evolving
 public final class Aggregation implements Serializable {
   private final AggregateFunc[] aggregateExpressions;
-  private final NamedReference[] groupByColumns;
+  private final Expression[] groupByExpressions;
 
-  public Aggregation(AggregateFunc[] aggregateExpressions, NamedReference[] groupByColumns) {
+  public Aggregation(AggregateFunc[] aggregateExpressions, Expression[] groupByExpressions) {
     this.aggregateExpressions = aggregateExpressions;
-    this.groupByColumns = groupByColumns;
+    this.groupByExpressions = groupByExpressions;
   }
 
   public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }
 
-  public NamedReference[] groupByColumns() { return groupByColumns; }
+  public Expression[] groupByExpressions() { return groupByExpressions; }
 }
```
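
A hedged usage sketch of the widened API (assuming Spark 3.3's `Sum(Expression, isDistinct)` signature): any V2 expression is now accepted as a grouping key, and a bare column still works because `NamedReference` is an `Expression`.

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Sum}

val groupKey: Expression = FieldReference.column("DEPT") // or any compiled expression tree
val agg = new Aggregation(
  Array[AggregateFunc](new Sum(FieldReference.column("SALARY"), false)),
  Array[Expression](groupKey)) // previously Array[NamedReference]
agg.groupByExpressions() // returns Expression[], replacing groupByColumns()
```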
DataSourceScanExec.scala
```diff
@@ -161,7 +161,7 @@ case class RowDataSourceScanExec(
       "PushedFilters" -> pushedFilters) ++
       pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
         Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())),
-          "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++
+          "PushedGroupByExpressions" -> seqToString(v.groupByExpressions.map(_.describe())))} ++
       topNOrLimitInfo ++
       pushedDownOperators.sample.map(v => "PushedSample" ->
         s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
```
AggregatePushDownUtils.scala
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.execution.RowToColumnConverter
 import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
@@ -93,8 +94,8 @@ object AggregatePushDownUtils {
       return None
     }
 
-    if (aggregation.groupByColumns.nonEmpty &&
-      partitionNames.size != aggregation.groupByColumns.length) {
+    if (aggregation.groupByExpressions.nonEmpty &&
+      partitionNames.size != aggregation.groupByExpressions.length) {
       // If there are group by columns, we only push down if the group by columns are the same as
       // the partition columns. In theory, if group by columns are a subset of partition columns,
       // we should still be able to push down. e.g. if table t has partition columns p1, p2, and p3,
@@ -106,11 +107,11 @@ object AggregatePushDownUtils {
       // aggregate push down simple and don't handle this complicate case for now.
       return None
     }
-    aggregation.groupByColumns.foreach { col =>
+    aggregation.groupByExpressions.map(extractColName).foreach { colName =>
       // don't push down if the group by columns are not the same as the partition columns (orders
       // doesn't matter because reorder can be done at data source layer)
-      if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None
-      finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head))
+      if (colName.isEmpty || !isPartitionCol(colName.get)) return None
+      finalSchema = finalSchema.add(getStructFieldForCol(colName.get))
     }
 
     aggregation.aggregateExpressions.foreach {
@@ -137,7 +138,8 @@ object AggregatePushDownUtils {
   def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = {
     a.aggregateExpressions.sortBy(_.hashCode())
       .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
-      a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
+      a.groupByExpressions.sortBy(_.hashCode())
+        .sameElements(b.groupByExpressions.sortBy(_.hashCode()))
   }
 
   /**
@@ -164,7 +166,7 @@ object AggregatePushDownUtils {
   def getSchemaWithoutGroupingExpression(
       aggSchema: StructType,
       aggregation: Aggregation): StructType = {
-    val numOfGroupByColumns = aggregation.groupByColumns.length
+    val numOfGroupByColumns = aggregation.groupByExpressions.length
     if (numOfGroupByColumns > 0) {
       new StructType(aggSchema.fields.drop(numOfGroupByColumns))
     } else {
@@ -179,7 +181,7 @@ object AggregatePushDownUtils {
       partitionSchema: StructType,
       aggregation: Aggregation,
       partitionValues: InternalRow): InternalRow = {
-    val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head)
+    val groupByColNames = aggregation.groupByExpressions.flatMap(extractColName)
     assert(groupByColNames.length == partitionSchema.length &&
       groupByColNames.length == partitionValues.numFields, "The number of group by columns " +
       s"${groupByColNames.length} should be the same as partition schema length " +
@@ -197,4 +199,9 @@ object AggregatePushDownUtils {
       partitionValues
     }
   }
+
+  private def extractColName(v2Expr: V2Expression): Option[String] = v2Expr match {
+    case f: FieldReference if f.fieldNames.length == 1 => Some(f.fieldNames.head)
+    case _ => None
+  }
 }
```
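
To make the file-source behavior concrete, a small self-contained check of the `extractColName` helper added above (a sketch; the helper itself is private): only a top-level single-part column reference yields a name, so Orc/Parquet push-down still bails out when a grouping key is anything other than a plain partition column.

```scala
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference}

def extractColName(v2Expr: V2Expression): Option[String] = v2Expr match {
  case f: FieldReference if f.fieldNames.length == 1 => Some(f.fieldNames.head)
  case _ => None
}

extractColName(FieldReference.column("p1")) // Some("p1"): can match a partition column
extractColName(FieldReference("a.b"))       // None: nested field, push-down returns None
```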
DataSourceStrategy.scala
```diff
@@ -759,14 +759,13 @@ object DataSourceStrategy
   protected[sql] def translateAggregation(
       aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {
 
-    def columnAsString(e: Expression): Option[FieldReference] = e match {
-      case PushableColumnWithoutNestedColumn(name) =>
-        Some(FieldReference.column(name).asInstanceOf[FieldReference])
+    def translateGroupBy(e: Expression): Option[V2Expression] = e match {
+      case PushableExpression(expr) => Some(expr)
       case _ => None
     }
 
     val translatedAggregates = aggregates.flatMap(translateAggregate)
-    val translatedGroupBys = groupBy.flatMap(columnAsString)
+    val translatedGroupBys = groupBy.flatMap(translateGroupBy)
 
     if (translatedAggregates.length != aggregates.length ||
       translatedGroupBys.length != groupBy.length) {
```
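
`PushableExpression` is added elsewhere in this patch and its definition is not in the hunks shown here; a plausible sketch, assuming it delegates to Spark 3.3's `V2ExpressionBuilder`, which succeeds only when the whole Catalyst expression translates to a V2 expression:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression}

object PushableExpression {
  // Succeeds only if `e` is fully translatable; otherwise translateAggregation
  // sees a length mismatch and gives up on the push-down.
  def unapply(e: Expression): Option[V2Expression] = new V2ExpressionBuilder(e).build()
}
```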
OrcUtils.scala
```diff
@@ -520,7 +520,7 @@ object OrcUtils extends Logging {
       val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupBy,
         (0 until schemaWithoutGroupBy.length).toArray)
       val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues)
-      if (aggregation.groupByColumns.nonEmpty) {
+      if (aggregation.groupByExpressions.nonEmpty) {
         val reOrderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
           partitionSchema, aggregation, partitionValues)
         new JoinedRow(reOrderedPartitionValues, resultRow)
```
ParquetUtils.scala
```diff
@@ -279,7 +279,7 @@ object ParquetUtils {
           throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
       }
 
-      if (aggregation.groupByColumns.nonEmpty) {
+      if (aggregation.groupByExpressions.nonEmpty) {
         val reorderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
           partitionSchema, aggregation, partitionValues)
         new JoinedRow(reorderedPartitionValues, converter.currentRecord)
```
V2ScanRelationPushDown.scala
```diff
@@ -183,9 +183,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
           // scalastyle:on
           val newOutput = scan.readSchema().toAttributes
           assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
-          val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
-            case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
-            case (_, b) => b
+          val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+          val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map {
+            case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId)
+            case ((expr, attr), ordinal) =>
+              if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) {
+                groupByExprToOutputOrdinal(expr.canonicalized) = ordinal
+              }
+              attr
           }
           val aggOutput = newOutput.drop(groupAttrs.length)
           val output = groupAttrs ++ aggOutput
@@ -196,7 +201,7 @@
              |Pushed Aggregate Functions:
              | ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
              |Pushed Group by:
-             | ${pushedAggregates.get.groupByColumns.mkString(", ")}
+             | ${pushedAggregates.get.groupByExpressions.mkString(", ")}
              |Output: ${output.mkString(", ")}
           """.stripMargin)
 
@@ -205,14 +210,15 @@
             DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
           if (r.supportCompletePushDown(pushedAggregates.get)) {
             val projectExpressions = finalResultExpressions.map { expr =>
-              // TODO At present, only push down group by attribute is supported.
-              // In future, more attribute conversion is extended here. e.g. GetStructField
-              expr.transform {
+              expr.transformDown {
                 case agg: AggregateExpression =>
                   val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
                   val child =
                     addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
                   Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
+                case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+                  val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+                  addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
               }
             }.asInstanceOf[Seq[NamedExpression]]
             Project(projectExpressions, scanRelation)
@@ -255,6 +261,9 @@
               case other => other
             }
             agg.copy(aggregateFunction = aggFunction)
+          case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+            val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+            addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
         }
       }
     }
```
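
A standalone toy (not Spark code) of the bookkeeping introduced here: the scan-output ordinal of each distinct non-attribute grouping expression is recorded under its canonical form, so later occurrences in the final project and aggregate lists can be swapped for the scan's output column at that ordinal.

```scala
import scala.collection.mutable

// Stand-in for expr.canonicalized: a normalized, comparable form of an expression.
final case class CanonicalExpr(repr: String)

val groupByExprToOutputOrdinal = mutable.HashMap.empty[CanonicalExpr, Int]
val groupingExprs = Seq(CanonicalExpr("CASE WHEN salary > 8000.0 ... END"), CanonicalExpr("dept"))

groupingExprs.zipWithIndex.foreach { case (expr, ordinal) =>
  // Only the first ordinal of a duplicated expression is kept.
  if (!groupByExprToOutputOrdinal.contains(expr)) {
    groupByExprToOutputOrdinal(expr) = ordinal
  }
}

// Rewrite phase: an expression found in the map is replaced by the scan output
// attribute at its ordinal (with a cast when the data types differ).
assert(groupByExprToOutputOrdinal(CanonicalExpr("dept")) == 1)
```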
JDBCScanBuilder.scala
```diff
@@ -20,7 +20,7 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.expressions.SortOrder
+import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
@@ -70,12 +70,15 @@ case class JDBCScanBuilder(
 
   private var pushedAggregateList: Array[String] = Array()
 
-  private var pushedGroupByCols: Option[Array[String]] = None
+  private var pushedGroupBys: Option[Array[String]] = None
 
   override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
-    lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames()
+    lazy val fieldNames = aggregation.groupByExpressions()(0) match {
+      case field: FieldReference => field.fieldNames
+      case _ => Array.empty[String]
+    }
     jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) ||
-      (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 &&
+      (aggregation.groupByExpressions().length == 1 && fieldNames.length == 1 &&
        jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_)))
   }
 
@@ -86,28 +89,26 @@ case class JDBCScanBuilder(
     val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
     if (compiledAggs.length != aggregation.aggregateExpressions.length) return false
 
-    val groupByCols = aggregation.groupByColumns.map { col =>
-      if (col.fieldNames.length != 1) return false
-      dialect.quoteIdentifier(col.fieldNames.head)
-    }
+    val compiledGroupBys = aggregation.groupByExpressions.flatMap(dialect.compileExpression)
+    if (compiledGroupBys.length != aggregation.groupByExpressions.length) return false
 
     // The column names here are already quoted and can be used to build sql string directly.
     // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
     // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
     // GROUP BY "DEPT", "NAME"
-    val selectList = groupByCols ++ compiledAggs
-    val groupByClause = if (groupByCols.isEmpty) {
+    val selectList = compiledGroupBys ++ compiledAggs
+    val groupByClause = if (compiledGroupBys.isEmpty) {
       ""
     } else {
-      "GROUP BY " + groupByCols.mkString(",")
+      "GROUP BY " + compiledGroupBys.mkString(",")
     }
 
     val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " +
       s"WHERE 1=0 $groupByClause"
     try {
       finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect)
       pushedAggregateList = selectList
-      pushedGroupByCols = Some(groupByCols)
+      pushedGroupBys = Some(compiledGroupBys)
       true
     } catch {
       case NonFatal(e) =>
@@ -173,6 +174,6 @@
     // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
     // be used in sql string.
     JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate,
-      pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders)
+      pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, sortOrders)
   }
 }
```
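
As a sketch of the end result for JDBC, the probe query built from the PR description's example would take roughly this shape; the compiled text comes from `dialect.compileExpression`, so the exact rendering depends on the dialect:

```scala
val compiledGroupBys = Array(
  """CASE WHEN "SALARY" > 8000.00 AND "SALARY" < 10000.00 THEN "SALARY" ELSE 0.00 END""")
val compiledAggs = Array("""SUM("SALARY")""")
val selectList = compiledGroupBys ++ compiledAggs
val groupByClause =
  if (compiledGroupBys.isEmpty) "" else "GROUP BY " + compiledGroupBys.mkString(",")
// SELECT CASE WHEN ... END,SUM("SALARY") FROM "test"."employee" WHERE 1=0 GROUP BY CASE WHEN ... END
val aggQuery =
  s"""SELECT ${selectList.mkString(",")} FROM "test"."employee" WHERE 1=0 $groupByClause"""
```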
OrcScan.scala
```diff
@@ -87,7 +87,7 @@ case class OrcScan(
 
   lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
     (seqToString(pushedAggregate.get.aggregateExpressions),
-      seqToString(pushedAggregate.get.groupByColumns))
+      seqToString(pushedAggregate.get.groupByExpressions))
   } else {
     ("[]", "[]")
   }
```
ParquetScan.scala
```diff
@@ -119,7 +119,7 @@ case class ParquetScan(
 
   lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
     (seqToString(pushedAggregate.get.aggregateExpressions),
-      seqToString(pushedAggregate.get.groupByColumns))
+      seqToString(pushedAggregate.get.groupByExpressions))
   } else {
     ("[]", "[]")
   }
```