# [SPARK-17949][SQL] A JVM object based aggregate operator
## What changes were proposed in this pull request?

This PR adds a new hash-based aggregate operator named `ObjectHashAggregateExec` that supports `TypedImperativeAggregate`, which may use arbitrary Java objects as aggregation states. Please refer to the [design doc](https://issues.apache.org/jira/secure/attachment/12834260/%5BDesign%20Doc%5D%20Support%20for%20Arbitrary%20Aggregation%20States.pdf) attached in [SPARK-17949](https://issues.apache.org/jira/browse/SPARK-17949) for more details about it.
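
To make the idea concrete, here is a minimal schematic of what a `TypedImperativeAggregate`-style function looks like. The trait below is an illustrative sketch, not the exact Spark API (the real class lives in `org.apache.spark.sql.catalyst.expressions.aggregate` and has more members):

```scala
import org.apache.spark.sql.catalyst.InternalRow

// Schematic sketch only: the aggregation state `T` can be any Java object,
// which is why a generic serialize/deserialize pair is required for
// shuffling and spilling.
trait TypedAggregateSketch[T] {
  def createAggregationBuffer(): T                 // fresh state for a new group
  def update(buffer: T, input: InternalRow): Unit  // fold one input row into the state
  def merge(buffer: T, other: T): Unit             // combine two partial states
  def serialize(buffer: T): Array[Byte]            // encode state for shuffle/spill
  def deserialize(bytes: Array[Byte]): T           // decode state after shuffle/spill
  def eval(buffer: T): Any                         // produce the final result value
}
```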

The major benefit of this operator is better performance when evaluating `TypedImperativeAggregate` functions, especially when there are relatively few distinct groups. Functions like Hive UDAFs, `collect_list`, and `collect_set` may also benefit from this after being migrated to `TypedImperativeAggregate`.

The following feature flag is introduced to enable or disable the new aggregate operator:
- Name: `spark.sql.execution.useObjectHashAggregateExec`
- Default value: `true`

The fallback threshold can be configured with the following SQL option (see the snippet after this list):
- Name: `spark.sql.objectHashAggregate.sortBased.fallbackThreshold`
- Default value: 128

  Falls back to sort-based aggregation when more than 128 distinct groups have been accumulated in the aggregation hash map. This number is intentionally small to avoid GC problems, since the aggregation buffers of this operator may contain arbitrary Java objects.

  This may be improved by implementing size tracking for this operator, but that can be done in a separate PR.
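
For reference, a minimal sketch of setting both knobs on an existing `SparkSession` (assumed here to be bound to `spark`); the values shown are the defaults:

```scala
// Assumes `spark` is an existing SparkSession; both values are the defaults.
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "true")
spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "128")
```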

Code generation and size tracking are planned to be implemented in follow-up PRs.

## Benchmark results
### `ObjectHashAggregateExec` vs `SortAggregateExec`

The first benchmark compares `ObjectHashAggregateExec` and `SortAggregateExec` by evaluating `typed_count`, a testing `TypedImperativeAggregate` version of the SQL `count` function.
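
For intuition, a hypothetical typed count along the lines of the schematic trait above needs nothing more than a mutable counter as its state (this is not the actual test helper):

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.InternalRow

// Hypothetical sketch of a typed count: the state is a one-element Long array
// so that update/merge can mutate it in place.
class TypedCountSketch extends TypedAggregateSketch[Array[Long]] {
  def createAggregationBuffer(): Array[Long] = Array(0L)
  def update(buffer: Array[Long], input: InternalRow): Unit = buffer(0) += 1
  def merge(buffer: Array[Long], other: Array[Long]): Unit = buffer(0) += other(0)
  def serialize(buffer: Array[Long]): Array[Byte] =
    ByteBuffer.allocate(java.lang.Long.BYTES).putLong(buffer(0)).array()
  def deserialize(bytes: Array[Byte]): Array[Long] =
    Array(ByteBuffer.wrap(bytes).getLong)
  def eval(buffer: Array[Long]): Any = buffer(0)
}
```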

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU  2.60GHz

object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
sort agg w/ group by                        31251 / 31908          3.4         298.0       1.0X
object agg w/ group by w/o fallback           6903 / 7141         15.2          65.8       4.5X
object agg w/ group by w/ fallback          20945 / 21613          5.0         199.7       1.5X
sort agg w/o group by                         4734 / 5463         22.1          45.2       6.6X
object agg w/o group by w/o fallback          4310 / 4529         24.3          41.1       7.3X
```

The next benchmark compares `ObjectHashAggregateExec` and `SortAggregateExec` by evaluating the Spark native version of `percentile_approx`.

Note that `percentile_approx` is such a heavy aggregate function that the bottleneck here is evaluating the aggregate function itself rather than the aggregate operator, and I couldn't run a larger-scale benchmark on my laptop. That's why the results are so close and look counter-intuitive (aggregation with grouping is even faster than aggregation without grouping).

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU  2.60GHz

object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
sort agg w/ group by                          3418 / 3530          0.6        1630.0       1.0X
object agg w/ group by w/o fallback           3210 / 3314          0.7        1530.7       1.1X
object agg w/ group by w/ fallback            3419 / 3511          0.6        1630.1       1.0X
sort agg w/o group by                         4336 / 4499          0.5        2067.3       0.8X
object agg w/o group by w/o fallback          4271 / 4372          0.5        2036.7       0.8X
```
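
For reference, the grouped variant of this benchmark boils down to a query of roughly this shape (the table and column names here are made up for illustration):

```scala
// Hypothetical query shape: a table `input` with columns `id` and `value` is assumed.
val grouped = spark.sql(
  "SELECT id % 100 AS bucket, percentile_approx(value, 0.5) FROM input GROUP BY id % 100")
```
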
### Hive UDAF vs Spark AF

This benchmark compares the following two kinds of aggregate functions:
- "hive udaf": Hive implementation of `percentile_approx`, without partial aggregation supports, evaluated using `SortAggregateExec`.
- "spark af": Spark native implementation of `percentile_approx`, with partial aggregation support, evaluated using `ObjectHashAggregateExec`

The performance differences are mostly due to faster implementation and partial aggregation support in the Spark native version of `percentile_approx`.

This benchmark basically shows the performance differences between the worst case, where an aggregate function without partial aggregation support is evaluated using `SortAggregateExec`, and the best case, where a `TypedImperativeAggregate` with partial aggregation support is evaluated using `ObjectHashAggregateExec`.

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU  2.60GHz

hive udaf vs spark af:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
hive udaf w/o group by                        5326 / 5408          0.0       81264.2       1.0X
spark af w/o group by                           93 /  111          0.7        1415.6      57.4X
hive udaf w/ group by                         3804 / 3946          0.0       58050.1       1.4X
spark af w/ group by w/o fallback               71 /   90          0.9        1085.7      74.8X
spark af w/ group by w/ fallback                98 /  111          0.7        1501.6      54.1X
```
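
A toy illustration of why partial aggregation support matters so much in the numbers above: with it, each partition is reduced to one small buffer per group before the shuffle, so the shuffle moves buffers instead of raw rows (plain Scala, no Spark involved):

```scala
// Toy model of partial aggregation for a count: two "partitions" of raw rows.
val partitions = Seq(Seq(1, 2, 3), Seq(4, 5))

// Map side ("update"): each partition pre-aggregates to one partial count,
// so only two Longs cross the (imaginary) shuffle instead of five rows.
val partialCounts = partitions.map(_.size.toLong)

// Reduce side ("merge"): combine the partial states into the final result.
val total = partialCounts.sum  // 5
```
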
### Real world benchmark

We also did a relatively large benchmark using a real world query involving `percentile_approx`:
- Hive UDAF implementation, sort-based aggregation, w/o partial aggregation support

  24.77 minutes
- Native implementation, sort-based aggregation, w/ partial aggregation support

  4.64 minutes
- Native implementation, object hash aggregator, w/ partial aggregation support

  1.80 minutes
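
End to end, that is roughly a 5.3x speedup from partial aggregation alone (24.77 / 4.64), a further 2.6x from the object hash aggregator (4.64 / 1.80), and about 13.8x overall.
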
## How was this patch tested?

New unit tests and randomized test cases are added in `ObjectAggregateFunctionSuite`.

Author: Cheng Lian <[email protected]>

Closes apache#15590 from liancheng/obj-hash-agg.
liancheng authored and yhuai committed Nov 3, 2016
1 parent 66a99f4 commit 27daf6b
Showing 10 changed files with 1,527 additions and 11 deletions.

Changes to `AggUtils.scala`:

```diff
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -66,14 +67,28 @@ object AggUtils {
         resultExpressions = resultExpressions,
         child = child)
     } else {
-      SortAggregateExec(
-        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-        groupingExpressions = groupingExpressions,
-        aggregateExpressions = aggregateExpressions,
-        aggregateAttributes = aggregateAttributes,
-        initialInputBufferOffset = initialInputBufferOffset,
-        resultExpressions = resultExpressions,
-        child = child)
+      val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
+      val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
+
+      if (objectHashEnabled && useObjectHash) {
+        ObjectHashAggregateExec(
+          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+          groupingExpressions = groupingExpressions,
+          aggregateExpressions = aggregateExpressions,
+          aggregateAttributes = aggregateAttributes,
+          initialInputBufferOffset = initialInputBufferOffset,
+          resultExpressions = resultExpressions,
+          child = child)
+      } else {
+        SortAggregateExec(
+          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+          groupingExpressions = groupingExpressions,
+          aggregateExpressions = aggregateExpressions,
+          aggregateAttributes = aggregateAttributes,
+          initialInputBufferOffset = initialInputBufferOffset,
+          resultExpressions = resultExpressions,
+          child = child)
+      }
     }
   }
 
```
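
`ObjectHashAggregateExec.supportsAggregate` is referenced above but not shown in this excerpt. A sketch of its likely shape, assuming it simply reports whether any aggregate function is a `TypedImperativeAggregate` (an assumption, since the companion object is in a file not shown):

```scala
// Assumed sketch of the check used by the planner above; the real method
// lives in ObjectHashAggregateExec's companion object (not shown here).
def supportsAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean =
  aggregateExpressions.map(_.aggregateFunction).exists {
    case _: TypedImperativeAggregate[_] => true
    case _ => false
  }
```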

New file `ObjectAggregationIterator.scala` (+323 lines):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering}
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

class ObjectAggregationIterator(
    outputAttributes: Seq[Attribute],
    groupingExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[AggregateExpression],
    aggregateAttributes: Seq[Attribute],
    initialInputBufferOffset: Int,
    resultExpressions: Seq[NamedExpression],
    newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
    originalInputAttributes: Seq[Attribute],
    inputRows: Iterator[InternalRow],
    fallbackCountThreshold: Int)
  extends AggregationIterator(
    groupingExpressions,
    originalInputAttributes,
    aggregateExpressions,
    aggregateAttributes,
    initialInputBufferOffset,
    resultExpressions,
    newMutableProjection) with Logging {

  // Indicates whether we have fallen back to sort-based aggregation or not.
  private[this] var sortBased: Boolean = false

  private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _

  // Hacks the aggregation mode so that AggregateFunction.merge is called to merge two
  // aggregation buffers.
  private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
    val newExpressions = aggregateExpressions.map {
      case agg @ AggregateExpression(_, Partial, _, _) =>
        agg.copy(mode = PartialMerge)
      case agg @ AggregateExpression(_, Complete, _, _) =>
        agg.copy(mode = Final)
      case other => other
    }
    val newFunctions = initializeAggregateFunctions(newExpressions, 0)
    val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
    generateProcessRow(newExpressions, newFunctions, newInputAttributes)
  }

  // A safe projection used to deep-clone input rows, preventing false sharing.
  private[this] val safeProjection: Projection =
    FromUnsafeProjection(outputAttributes.map(_.dataType))

  /**
   * Start processing input rows.
   */
  processInputs()

  override final def hasNext: Boolean = {
    aggBufferIterator.hasNext
  }

  override final def next(): UnsafeRow = {
    val entry = aggBufferIterator.next()
    generateOutput(entry.groupingKey, entry.aggregationBuffer)
  }

  /**
   * Generate an output row when there is no input and there is no grouping expression.
   */
  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
    if (groupingExpressions.isEmpty) {
      val defaultAggregationBuffer = createNewAggregationBuffer()
      generateOutput(UnsafeRow.createFromByteArray(0, 0), defaultAggregationBuffer)
    } else {
      throw new IllegalStateException(
        "This method should not be called when groupingExpressions is not empty.")
    }
  }

  // Creates a new aggregation buffer and initializes buffer values. This function should only be
  // called in two cases:
  //
  //  - when creating an aggregation buffer for a new group in the hash map, and
  //  - when creating the reused buffer for sort-based aggregation
  private def createNewAggregationBuffer(): SpecificInternalRow = {
    val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType))
    val buffer = new SpecificInternalRow(bufferFieldTypes)
    initAggregationBuffer(buffer)
    buffer
  }

  private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = {
    // Initializes declarative aggregates' buffer values
    expressionAggInitialProjection.target(buffer)(EmptyRow)
    // Initializes imperative aggregates' buffer values
    aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
  }

  private def getAggregationBufferByKey(
      hashMap: ObjectAggregationMap, groupingKey: UnsafeRow): InternalRow = {
    var aggBuffer = hashMap.getAggregationBuffer(groupingKey)

    if (aggBuffer == null) {
      aggBuffer = createNewAggregationBuffer()
      hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer)
    }

    aggBuffer
  }

  // This function is used to read and process input rows. When processing input rows, it first
  // uses hash-based aggregation by putting groups and their buffers in `hashMap`. If `hashMap`
  // grows too large, its entries are dumped into an external sorter and we fall back to
  // sort-based aggregation: the sorted aggregation buffers are then merged with the remaining
  // (sorted) input rows to produce the final results.
  private def processInputs(): Unit = {
    // In-memory map to store aggregation buffers for hash-based aggregation.
    val hashMap = new ObjectAggregationMap()

    // If the in-memory map is unable to store all aggregation buffers, fall back to sort-based
    // aggregation backed by sorted physical storage.
    var sortBasedAggregationStore: SortBasedAggregator = null

    if (groupingExpressions.isEmpty) {
      // If there are no grouping expressions, we can just reuse the same buffer over and over.
      val groupingKey = groupingProjection.apply(null)
      val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
      while (inputRows.hasNext) {
        val newInput = safeProjection(inputRows.next())
        processRow(buffer, newInput)
      }
    } else {
      while (inputRows.hasNext && !sortBased) {
        val newInput = safeProjection(inputRows.next())
        val groupingKey = groupingProjection.apply(newInput)
        val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
        processRow(buffer, newInput)

        // If the hash map gets too large, fall back to sort-based aggregation.
        if (hashMap.size >= fallbackCountThreshold) {
          logInfo(
            "Aggregation hash map reached the threshold capacity " +
              s"($fallbackCountThreshold entries); spilling and falling back to sort-based " +
              "aggregation. You may change the threshold by adjusting the option " +
              SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key)

          sortBased = true
        }
      }

      if (sortBased) {
        val sortIteratorFromHashMap = hashMap
          .dumpToExternalSorter(groupingAttributes, aggregateFunctions)
          .sortedIterator()
        sortBasedAggregationStore = new SortBasedAggregator(
          sortIteratorFromHashMap,
          StructType.fromAttributes(originalInputAttributes),
          StructType.fromAttributes(groupingAttributes),
          processRow,
          mergeAggregationBuffers,
          createNewAggregationBuffer())

        while (inputRows.hasNext) {
          // NOTE: the input row is always an UnsafeRow
          val unsafeInputRow = inputRows.next().asInstanceOf[UnsafeRow]
          val groupingKey = groupingProjection.apply(unsafeInputRow)
          sortBasedAggregationStore.addInput(groupingKey, unsafeInputRow)
        }
      }
    }

    if (sortBased) {
      aggBufferIterator = sortBasedAggregationStore.destructiveIterator()
    } else {
      aggBufferIterator = hashMap.iterator
    }
  }
}

/**
 * A class used to handle sort-based aggregation, used together with [[ObjectHashAggregateExec]].
 *
 * @param initialAggBufferIterator iterator that points to sorted input aggregation buffers
 * @param inputSchema The schema of input row
 * @param groupingSchema The schema of grouping key
 * @param processRow Function to update the aggregation buffer with input rows
 * @param mergeAggregationBuffers Function used to merge the input aggregation buffers into
 *                                existing aggregation buffers
 * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer
 *
 * @todo Try to eliminate this class by refactoring and reusing code paths in
 *       [[SortAggregateExec]].
 */
class SortBasedAggregator(
    initialAggBufferIterator: KVIterator[UnsafeRow, UnsafeRow],
    inputSchema: StructType,
    groupingSchema: StructType,
    processRow: (InternalRow, InternalRow) => Unit,
    mergeAggregationBuffers: (InternalRow, InternalRow) => Unit,
    makeEmptyAggregationBuffer: => InternalRow) {

  // External sorter that sorts the input (grouping key + input row) by the grouping key.
  private val inputSorter = createExternalSorterForInput()
  private val groupingKeyOrdering: BaseOrdering = GenerateOrdering.create(groupingSchema)

  def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = {
    inputSorter.insertKV(groupingKey, inputRow)
  }

  /**
   * Returns a destructive iterator of AggregationBufferEntry.
   * Notice: it is illegal to call any other method after `destructiveIterator()` has been called.
   */
  def destructiveIterator(): Iterator[AggregationBufferEntry] = {
    new Iterator[AggregationBufferEntry] {
      val inputIterator = inputSorter.sortedIterator()
      var hasNextInput: Boolean = inputIterator.next()
      var hasNextAggBuffer: Boolean = initialAggBufferIterator.next()
      private var result: AggregationBufferEntry = _
      private var groupingKey: UnsafeRow = _

      override def hasNext(): Boolean = {
        result != null || findNextSortedGroup()
      }

      override def next(): AggregationBufferEntry = {
        val returnResult = result
        result = null
        returnResult
      }

      // Two-way merges initialAggBufferIterator and inputIterator
      private def findNextSortedGroup(): Boolean = {
        if (hasNextInput || hasNextAggBuffer) {
          // Find the smaller grouping key between inputIterator and initialAggBufferIterator
          groupingKey = findGroupingKey()
          result = new AggregationBufferEntry(groupingKey, makeEmptyAggregationBuffer)

          // First, update the aggregation buffer with input rows.
          while (hasNextInput &&
            groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
            processRow(result.aggregationBuffer, inputIterator.getValue)
            hasNextInput = inputIterator.next()
          }

          // Then, merge the aggregation buffer with existing aggregation buffers.
          // NOTE: the ordering of these two while-blocks matters; mergeAggregationBuffers()
          // must be called after processRow.
          while (hasNextAggBuffer &&
            groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
            mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
            hasNextAggBuffer = initialAggBufferIterator.next()
          }

          true
        } else {
          false
        }
      }

      private def findGroupingKey(): UnsafeRow = {
        var newGroupingKey: UnsafeRow = null
        if (!hasNextInput) {
          newGroupingKey = initialAggBufferIterator.getKey
        } else if (!hasNextAggBuffer) {
          newGroupingKey = inputIterator.getKey
        } else {
          val compareResult =
            groupingKeyOrdering.compare(inputIterator.getKey, initialAggBufferIterator.getKey)
          if (compareResult <= 0) {
            newGroupingKey = inputIterator.getKey
          } else {
            newGroupingKey = initialAggBufferIterator.getKey
          }
        }

        if (groupingKey == null) {
          groupingKey = newGroupingKey.copy()
        } else {
          groupingKey.copyFrom(newGroupingKey)
        }
        groupingKey
      }
    }
  }

  private def createExternalSorterForInput(): UnsafeKVExternalSorter = {
    new UnsafeKVExternalSorter(
      groupingSchema,
      inputSchema,
      SparkEnv.get.blockManager,
      SparkEnv.get.serializerManager,
      TaskContext.get().taskMemoryManager().pageSizeBytes,
      SparkEnv.get.conf.getLong(
        "spark.shuffle.spill.numElementsForceSpillThreshold",
        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
      null)
  }
}
```

(The remaining changed files from this commit are not shown here.)
