Skip to content

Commit

Permalink
[SPARK-9577][SQL] Surface concrete iterator types in various sort cla…
Browse files Browse the repository at this point in the history
…sses.

We often return abstract iterator types in various sort-related classes (e.g. UnsafeKVExternalSorter). It is actually better to return a more concrete type, so the callsite uses that type and JIT can inline the iterator calls.

Author: Reynold Xin <[email protected]>

Closes apache#7911 from rxin/surface-concrete-type and squashes the following commits:

0422add [Reynold Xin] [SPARK-9577][SQL] Surface concrete iterator types in various sort classes.
  • Loading branch information
rxin committed Aug 4, 2015
1 parent 3b0e444 commit 5eb89f6
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ public void insertKVRecord(

public UnsafeSorterIterator getSortedIterator() throws IOException {
assert(inMemSorter != null);
final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
if (spillWriters.isEmpty()) {
return inMemoryIterator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public void insertRecord(long recordPointer, long keyPrefix) {
pointerArrayInsertPosition++;
}

private static final class SortedIterator extends UnsafeSorterIterator {
public static final class SortedIterator extends UnsafeSorterIterator {

private final TaskMemoryManager memoryManager;
private final int sortBufferInsertPosition;
Expand All @@ -144,7 +144,7 @@ private static final class SortedIterator extends UnsafeSorterIterator {
private long keyPrefix;
private int recordLength;

SortedIterator(
private SortedIterator(
TaskMemoryManager memoryManager,
int sortBufferInsertPosition,
long[] sortBuffer) {
Expand Down Expand Up @@ -186,7 +186,7 @@ public void loadNext() {
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
* {@code next()} will return the same mutable object.
*/
public UnsafeSorterIterator getSortedIterator() {
public SortedIterator getSortedIterator() {
sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,66 +134,15 @@ public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
}

public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException {
public KVSorterIterator sortedIterator() throws IOException {
try {
final UnsafeSorterIterator underlying = sorter.getSortedIterator();
if (!underlying.hasNext()) {
// Since we won't ever call next() on an empty iterator, we need to clean up resources
// here in order to prevent memory leaks.
cleanupResources();
}

return new KVIterator<UnsafeRow, UnsafeRow>() {
private UnsafeRow key = new UnsafeRow();
private UnsafeRow value = new UnsafeRow();
private int numKeyFields = keySchema.size();
private int numValueFields = valueSchema.size();

@Override
public boolean next() throws IOException {
try {
if (underlying.hasNext()) {
underlying.loadNext();

Object baseObj = underlying.getBaseObject();
long recordOffset = underlying.getBaseOffset();
int recordLen = underlying.getRecordLength();

// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
int valueLen = recordLen - keyLen - 4;

key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);

return true;
} else {
key = null;
value = null;
cleanupResources();
return false;
}
} catch (IOException e) {
cleanupResources();
throw e;
}
}

@Override
public UnsafeRow getKey() {
return key;
}

@Override
public UnsafeRow getValue() {
return value;
}

@Override
public void close() {
cleanupResources();
}
};
return new KVSorterIterator(underlying);
} catch (IOException e) {
cleanupResources();
throw e;
Expand Down Expand Up @@ -233,4 +182,61 @@ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff
return ordering.compare(row1, row2);
}
}

public class KVSorterIterator extends KVIterator<UnsafeRow, UnsafeRow> {
private UnsafeRow key = new UnsafeRow();
private UnsafeRow value = new UnsafeRow();
private final int numKeyFields = keySchema.size();
private final int numValueFields = valueSchema.size();
private final UnsafeSorterIterator underlying;

private KVSorterIterator(UnsafeSorterIterator underlying) {
this.underlying = underlying;
}

@Override
public boolean next() throws IOException {
try {
if (underlying.hasNext()) {
underlying.loadNext();

Object baseObj = underlying.getBaseObject();
long recordOffset = underlying.getBaseOffset();
int recordLen = underlying.getRecordLength();

// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
int valueLen = recordLen - keyLen - 4;

key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);

return true;
} else {
key = null;
value = null;
cleanupResources();
return false;
}
} catch (IOException e) {
cleanupResources();
throw e;
}
}

@Override
public UnsafeRow getKey() {
return key;
}

@Override
public UnsafeRow getValue() {
return value;
}

@Override
public void close() {
cleanupResources();
}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap}
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.types.StructType

/**
Expand Down Expand Up @@ -230,7 +230,7 @@ class UnsafeHybridAggregationIterator(
}

// Step 5: Get the sorted iterator from the externalSorter.
val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator()
val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator()

// Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
// For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
Expand Down Expand Up @@ -368,31 +368,5 @@ object UnsafeHybridAggregationIterator {
newMutableProjection,
outputsUnsafeRows)
}

def createFromKVIterator(
groupingKeyAttributes: Seq[Attribute],
valueAttributes: Seq[Attribute],
inputKVIterator: KVIterator[UnsafeRow, InternalRow],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
completeAggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
new UnsafeHybridAggregationIterator(
groupingKeyAttributes,
valueAttributes,
inputKVIterator,
nonCompleteAggregateExpressions,
nonCompleteAggregateAttributes,
completeAggregateExpressions,
completeAggregateAttributes,
initialInputBufferOffset,
resultExpressions,
newMutableProjection,
outputsUnsafeRows)
}
// scalastyle:on
}

0 comments on commit 5eb89f6

Please sign in to comment.