diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index e32ad9c036ad4..7c9dc8e5f88ef 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.hash
 import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
@@ -35,8 +36,8 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
-      Serializer.getSerializer(dep.serializer))
+    val ser = Serializer.getSerializer(dep.serializer)
+    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
@@ -54,16 +55,13 @@ private[spark] class HashShuffleReader[K, C](
     // Sort the output if there is a sort ordering defined.
     dep.keyOrdering match {
       case Some(keyOrd: Ordering[K]) =>
-        // Define a Comparator for the whole record based on the key Ordering.
-        val cmp = new Ordering[Product2[K, C]] {
-          override def compare(o1: Product2[K, C], o2: Product2[K, C]): Int = {
-            keyOrd.compare(o1._1, o2._1)
-          }
-        }
-        val sortBuffer: Array[Product2[K, C]] = aggregatedIter.toArray
-        // TODO: do external sort.
-        scala.util.Sorting.quickSort(sortBuffer)(cmp)
-        sortBuffer.iterator
+        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
+        // the ExternalSorter won't spill to disk.
+        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
+        sorter.write(aggregatedIter)
+        context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
+        context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
+        sorter.iterator
       case None =>
         aggregatedIter
     }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index ddb5df40360e9..65a71e5a83698 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -190,6 +190,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
         fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
       }
     }
+
+    // sortByKey - should spill ~17 times
+    val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i))
+    val resultE = rddE.sortByKey().collect().toSeq
+    assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
   }
 
   test("spilling in local cluster with many reduce tasks") {
@@ -256,6 +261,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
         fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
       }
     }
+
+    // sortByKey - should spill ~8 times per executor
+    val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i))
+    val resultE = rddE.sortByKey().collect().toSeq
+    assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
   }
 
   test("cleanup of intermediate files in sorter") {
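
For context only (not part of the patch): a minimal driver sketch of how the new reduce-side sort path gets exercised. sortByKey attaches a key ordering to the shuffle dependency, so the reader above now streams the fetched records through ExternalSorter instead of buffering them for an in-memory quickSort. The object name and the tiny spark.shuffle.memoryFraction value below are illustrative assumptions chosen to force spilling, mirroring what the new tests do; this is a sketch, not part of the change set.

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative sketch (assumed names): run sortByKey with a very small shuffle memory
// budget so the reduce-side ExternalSorter actually spills to disk.
object SortByKeySpillSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local-cluster[2,1,512]")
      .setAppName("SortByKeySpillSketch")
      .set("spark.shuffle.spill", "true")            // let ExternalSorter spill
      .set("spark.shuffle.memoryFraction", "0.001")  // tiny budget to force spills
    val sc = new SparkContext(conf)
    try {
      // Same data shape as the new tests: each key repeats four times.
      val sorted = sc.parallelize(0 until 100000).map(i => (i / 4, i)).sortByKey().collect().toSeq
      assert(sorted == (0 until 100000).map(i => (i / 4, i)).toSeq)
    } finally {
      sc.stop()
    }
  }
}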