[SPARK-9753] [SQL] TungstenAggregate should also accept InternalRow instead of just UnsafeRow

https://issues.apache.org/jira/browse/SPARK-9753

This PR makes TungstenAggregate accept `InternalRow` instead of just `UnsafeRow`. It also adds a `getAggregationBufferFromUnsafeRow` method to `UnsafeFixedWidthAggregationMap`, which is useful when the grouping keys are already stored as `UnsafeRow`s. Finally, it wraps the `InputStream` and `OutputStream` used by `UnsafeRowSerializer` in a `BufferedInputStream` and a `BufferedOutputStream`, respectively.

Author: Yin Huai <[email protected]>

Closes apache#8041 from yhuai/joinedRowForProjection and squashes the following commits:

7753e34 [Yin Huai] Use BufferedInputStream and BufferedOutputStream.
d68b74e [Yin Huai] Use joinedRow instead of UnsafeRowJoiner.
e93c009 [Yin Huai] Add getAggregationBufferFromUnsafeRow for cases that the given groupingKeyRow is already an UnsafeRow.

(cherry picked from commit c564b27)
Signed-off-by: Reynold Xin <[email protected]>
yhuai authored and rxin committed Aug 8, 2015
1 parent 5598b62 commit 47e4735
Showing 4 changed files with 39 additions and 50 deletions.
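
Before the file-by-file diff, a minimal sketch of the core idea. The types below are illustrative stand-ins, not Spark's InternalRow, JoinedRow, or GenerateUnsafeRowJoiner: a joined view routes field accesses to the two underlying rows instead of byte-copying them, so it accepts any row implementation, while a generated unsafe-row joiner copies bytes and therefore needs both sides to be UnsafeRows.

trait Row { def numFields: Int; def get(i: Int): Any }

final class SeqRow(values: IndexedSeq[Any]) extends Row {
  def numFields: Int = values.length
  def get(i: Int): Any = values(i)
}

// A lazy concatenation of two rows; no copying, so either side may be a
// "safe" or an "unsafe" row.
final class JoinedRowView(left: Row, right: Row) extends Row {
  def numFields: Int = left.numFields + right.numFields
  def get(i: Int): Any =
    if (i < left.numFields) left.get(i) else right.get(i - left.numFields)
}

object JoinedRowDemo extends App {
  val buffer = new SeqRow(IndexedSeq(0L, 0.0))  // aggregation buffer fields
  val input  = new SeqRow(IndexedSeq("a", 42))  // incoming input row fields
  val joined = new JoinedRowView(buffer, input)
  assert(joined.numFields == 4 && joined.get(2) == "a")
}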
sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -121,6 +121,10 @@ public UnsafeFixedWidthAggregationMap(
public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);

return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
}

public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) {
// Probe our map using the serialized key
final BytesToBytesMap.Location loc = map.lookup(
unsafeGroupingKeyRow.getBaseObject(),
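A hypothetical sketch of the two-entry-point pattern added above. The simplified types are ours, not Spark's, and unlike this sketch the real getAggregationBufferFromUnsafeRow can return null when no more memory can be acquired:

final case class GenericKey(values: Seq[Any])
final case class UnsafeKey(bytes: Vector[Byte])
final class AggBuffer(var sum: Long)

final class FixedWidthAggMap(project: GenericKey => UnsafeKey) {
  private val buffers = scala.collection.mutable.Map.empty[UnsafeKey, AggBuffer]

  // General path: project the grouping key into the unsafe format, then
  // delegate to the fast path below.
  def getAggregationBuffer(key: GenericKey): AggBuffer =
    getAggregationBufferFromUnsafeKey(project(key))

  // Fast path: the caller already holds an unsafe key, so the projection
  // is skipped entirely.
  def getAggregationBufferFromUnsafeKey(key: UnsafeKey): AggBuffer =
    buffers.getOrElseUpdate(key, new AggBuffer(0L))
}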
sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -58,27 +58,14 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
*/
override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
// When `out` is backed by ChainedBufferOutputStream, we will get an
// UnsupportedOperationException when we call dOut.writeInt because it internally calls
// ChainedBufferOutputStream's write(b: Int), which is not supported.
// To work around this issue, we create an array for storing the int value.
// To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and
// run SparkSqlSerializer2SortMergeShuffleSuite.
private[this] var intBuffer: Array[Byte] = new Array[Byte](4)
private[this] val dOut: DataOutputStream = new DataOutputStream(out)
private[this] val dOut: DataOutputStream =
new DataOutputStream(new BufferedOutputStream(out))

override def writeValue[T: ClassTag](value: T): SerializationStream = {
val row = value.asInstanceOf[UnsafeRow]
val size = row.getSizeInBytes
// This part is based on DataOutputStream's writeInt.
// It is for dOut.writeInt(row.getSizeInBytes).
intBuffer(0) = ((size >>> 24) & 0xFF).toByte
intBuffer(1) = ((size >>> 16) & 0xFF).toByte
intBuffer(2) = ((size >>> 8) & 0xFF).toByte
intBuffer(3) = ((size >>> 0) & 0xFF).toByte
dOut.write(intBuffer, 0, 4)

row.writeToStream(out, writeBuffer)

dOut.writeInt(row.getSizeInBytes)
row.writeToStream(dOut, writeBuffer)
this
}

@@ -105,15 +92,14 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst

override def close(): Unit = {
writeBuffer = null
intBuffer = null
dOut.writeInt(EOF)
dOut.close()
}
}

override def deserializeStream(in: InputStream): DeserializationStream = {
new DeserializationStream {
private[this] val dIn: DataInputStream = new DataInputStream(in)
private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))
// 1024 is a default buffer size; this buffer will grow to accommodate larger rows
private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
private[this] var row: UnsafeRow = new UnsafeRow()
@@ -129,7 +115,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
if (rowBuffer.length < rowSize) {
rowBuffer = new Array[Byte](rowSize)
}
ByteStreams.readFully(in, rowBuffer, 0, rowSize)
ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize)
rowSize = dIn.readInt() // read the next row's size
if (rowSize == EOF) { // We are returning the last row in this stream
@@ -163,7 +149,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
if (rowBuffer.length < rowSize) {
rowBuffer = new Array[Byte](rowSize)
}
ByteStreams.readFully(in, rowBuffer, 0, rowSize)
ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize)
row.asInstanceOf[T]
}
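The stream changes above replace a hand-rolled int encoding (the old intBuffer workaround for an OutputStream without a usable write(b: Int)) with a plain DataOutputStream.writeInt over a BufferedOutputStream. One subtlety visible in the diff: once the input is buffered, every read must go through dIn rather than the raw in, because the buffer may already have consumed bytes from the underlying stream. A standalone sketch of the framing pattern, using only java.io (the demo object is ours):

import java.io._

object BufferedFramingDemo extends App {
  // Write a length-prefixed payload through a buffered DataOutputStream,
  // mirroring writeValue's dOut.writeInt(size) + row.writeToStream(dOut, ...).
  val sink = new ByteArrayOutputStream()
  val dOut = new DataOutputStream(new BufferedOutputStream(sink))
  val payload = "row-bytes".getBytes("UTF-8")
  dOut.writeInt(payload.length)
  dOut.write(payload)
  dOut.flush() // buffered data only reaches `sink` after a flush/close

  // Read everything through the buffered dIn, never the raw stream, or
  // bytes already pulled into the buffer would be skipped.
  val dIn = new DataInputStream(
    new BufferedInputStream(new ByteArrayInputStream(sink.toByteArray)))
  val size = dIn.readInt()
  val rowBuffer = new Array[Byte](size)
  dIn.readFully(rowBuffer, 0, size) // plays the role of ByteStreams.readFully
  assert(new String(rowBuffer, "UTF-8") == "row-bytes")
}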
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -39,7 +39,7 @@ case class TungstenAggregate(

override def canProcessUnsafeRows: Boolean = true

override def canProcessSafeRows: Boolean = false
override def canProcessSafeRows: Boolean = true

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

@@ -77,7 +77,7 @@ case class TungstenAggregate(
resultExpressions,
newMutableProjection,
child.output,
iter.asInstanceOf[Iterator[UnsafeRow]],
iter,
testFallbackStartsAt)

if (!hasInput && groupingExpressions.isEmpty) {
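The one-line flip above (canProcessSafeRows from false to true) is what lets the planner hand TungstenAggregate any InternalRow iterator, so the asInstanceOf[Iterator[UnsafeRow]] cast can be dropped. Below is a hedged sketch of how such capability flags are typically consumed; it is a simplified model, not Spark's actual row-format conversion rule:

object RowFormatDemo extends App {
  sealed trait RowFormat
  case object SafeRows extends RowFormat
  case object UnsafeRows extends RowFormat

  final case class Operator(
      name: String,
      canProcessSafeRows: Boolean,
      canProcessUnsafeRows: Boolean)

  // True when a row-format conversion would have to be inserted between a
  // child producing childFormat and op.
  def needsConversion(op: Operator, childFormat: RowFormat): Boolean =
    childFormat match {
      case SafeRows   => !op.canProcessSafeRows
      case UnsafeRows => !op.canProcessUnsafeRows
    }

  // After this patch, a TungstenAggregate-like operator accepts both
  // formats, so no conversion is needed on either side.
  val agg = Operator("TungstenAggregate",
    canProcessSafeRows = true, canProcessUnsafeRows = true)
  assert(!needsConversion(agg, SafeRows) && !needsConversion(agg, UnsafeRows))
}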
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -22,6 +22,7 @@ import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.types.StructType

@@ -46,8 +47,7 @@ import org.apache.spark.sql.types.StructType
* processing input rows from inputIter, and generating output
* rows.
* - Part 3: Methods and fields used by hash-based aggregation.
* - Part 4: The function used to switch this iterator from hash-based
* aggregation to sort-based aggregation.
* - Part 4: Methods and fields used when we switch to sort-based aggregation.
* - Part 5: Methods and fields used by sort-based aggregation.
* - Part 6: Loads input and process input rows.
* - Part 7: Public methods of this iterator.
@@ -82,7 +82,7 @@ class TungstenAggregationIterator(
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
inputIter: Iterator[UnsafeRow],
inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[Int])
extends Iterator[UnsafeRow] with Logging {

@@ -174,13 +174,10 @@ class TungstenAggregationIterator(

// Creates a function used to process a row based on the given inputAttributes.
private def generateProcessRow(
inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = {
inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = {

val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes)
val inputSchema = StructType.fromAttributes(inputAttributes)
val unsafeRowJoiner =
GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema)
val joinedRow = new JoinedRow()

aggregationMode match {
// Partial-only
@@ -189,9 +186,9 @@ class TungstenAggregationIterator(
val algebraicUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()

(currentBuffer: UnsafeRow, row: UnsafeRow) => {
(currentBuffer: UnsafeRow, row: InternalRow) => {
algebraicUpdateProjection.target(currentBuffer)
algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
algebraicUpdateProjection(joinedRow(currentBuffer, row))
}

// PartialMerge-only or Final-only
@@ -203,10 +200,10 @@ class TungstenAggregationIterator(
mergeExpressions,
aggregationBufferAttributes ++ inputAttributes)()

(currentBuffer: UnsafeRow, row: UnsafeRow) => {
(currentBuffer: UnsafeRow, row: InternalRow) => {
// Process all algebraic aggregate functions.
algebraicMergeProjection.target(currentBuffer)
algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row))
algebraicMergeProjection(joinedRow(currentBuffer, row))
}

// Final-Complete
@@ -233,8 +230,8 @@ class TungstenAggregationIterator(
val completeAlgebraicUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()

(currentBuffer: UnsafeRow, row: UnsafeRow) => {
val input = unsafeRowJoiner.join(currentBuffer, row)
(currentBuffer: UnsafeRow, row: InternalRow) => {
val input = joinedRow(currentBuffer, row)
// For all aggregate functions with mode Complete, update the given currentBuffer.
completeAlgebraicUpdateProjection.target(currentBuffer)(input)

@@ -253,14 +250,14 @@ class TungstenAggregationIterator(
val completeAlgebraicUpdateProjection =
newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()

(currentBuffer: UnsafeRow, row: UnsafeRow) => {
(currentBuffer: UnsafeRow, row: InternalRow) => {
completeAlgebraicUpdateProjection.target(currentBuffer)
// For all aggregate functions with mode Complete, update the given currentBuffer.
completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row))
}

// Grouping only.
case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {}
case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {}

case other =>
throw new IllegalStateException(
@@ -272,27 +269,29 @@ class TungstenAggregationIterator(
private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {

val groupingAttributes = groupingExpressions.map(_.toAttribute)
val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
val bufferSchema = StructType.fromAttributes(bufferAttributes)
val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)

aggregationMode match {
// Partial-only or PartialMerge-only: every output row is basically the values of
// the grouping expressions and the corresponding aggregation buffer.
case (Some(Partial), None) | (Some(PartialMerge), None) =>
val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
val bufferSchema = StructType.fromAttributes(bufferAttributes)
val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)

(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
}

// Final-only, Complete-only and Final-Complete: an output row is generated based on
// resultExpressions.
case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
val joinedRow = new JoinedRow()
val resultProjection =
UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)

(currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer))
resultProjection(joinedRow(currentGroupingKey, currentBuffer))
}

// Grouping-only: an output row is generated from values of grouping expressions.
@@ -316,7 +315,7 @@ class TungstenAggregationIterator(

// A function used to process an input row. Its first argument is the aggregation buffer
// and the second argument is the input row.
private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit =
private[this] var processRow: (UnsafeRow, InternalRow) => Unit =
generateProcessRow(originalInputAttributes)

// A function used to generate output rows based on the grouping keys (first argument)
@@ -354,7 +353,7 @@ class TungstenAggregationIterator(
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
val groupingKey = groupProjection.apply(newInput)
val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey)
val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
if (buffer == null) {
// buffer == null means that we could not allocate more memory.
// Now, we need to spill the map and switch to sort-based aggregation.
@@ -374,7 +373,7 @@ class TungstenAggregationIterator(
val newInput = inputIter.next()
val groupingKey = groupProjection.apply(newInput)
val buffer: UnsafeRow = if (i < fallbackStartsAt) {
hashMap.getAggregationBuffer(groupingKey)
hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
} else {
null
}
@@ -397,7 +396,7 @@ class TungstenAggregationIterator(
private[this] var mapIteratorHasNext: Boolean = false

///////////////////////////////////////////////////////////////////////////
// Part 3: Methods and fields used by sort-based aggregation.
// Part 4: Methods and fields used when we switch to sort-based aggregation.
///////////////////////////////////////////////////////////////////////////

// This sorter is used for sort-based aggregation. It is initialized as soon as
@@ -407,7 +406,7 @@ class TungstenAggregationIterator(
/**
* Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
*/
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = {
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
logInfo("falling back to sort based aggregation.")
// Step 1: Get the ExternalSorter containing sorted entries of the map.
externalSorter = hashMap.destructAndCreateExternalSorter()
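The control flow in processInputs above follows a simple pattern: probe the hash map; a null buffer means no more memory could be acquired, so the current key and row plus the rest of the input are handed to the sort-based path. A compact generic sketch of that pattern (names and signatures are ours, not Spark's):

object FallbackDemo extends App {
  // Generic model of hash-first aggregation with a sort-based fallback.
  def aggregate[R, K, B](
      input: Iterator[R],
      keyOf: R => K,
      tryHashBuffer: K => Option[B], // None models a null buffer (no memory)
      update: (B, R) => Unit,
      switchToSortBased: (K, R, Iterator[R]) => Unit): Unit = {
    var sortBased = false
    while (!sortBased && input.hasNext) {
      val row = input.next()
      val key = keyOf(row)
      tryHashBuffer(key) match {
        case Some(buffer) => update(buffer, row) // hash-based path
        case None => // could not allocate: spill and finish sort-based
          switchToSortBased(key, row, input)
          sortBased = true
      }
    }
  }

  // Tiny usage: sum values per key with an "unlimited" in-memory map.
  val sums = scala.collection.mutable.Map.empty[String, Array[Long]]
  aggregate[(String, Long), String, Array[Long]](
    Iterator("a" -> 1L, "b" -> 2L, "a" -> 3L),
    _._1,
    k => Some(sums.getOrElseUpdate(k, Array(0L))),
    (buf, r) => buf(0) += r._2,
    (_, _, _) => ())
  assert(sums("a")(0) == 4L && sums("b")(0) == 2L)
}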
