[SPARK-10403] Allow UnsafeRowSerializer to work with tungsten-sort ShuffleManager

This patch attempts to fix an issue where Spark SQL's UnsafeRowSerializer was incompatible with the `tungsten-sort` ShuffleManager.

Author: Josh Rosen <[email protected]>

Closes #8873 from JoshRosen/SPARK-10403.
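
For readers skimming the diff: the serializer frames each record on the wire as a 4-byte big-endian length followed by the record's bytes ("written high byte first", per the doc comment below). Previously, close() also wrote a -1 sentinel to mark end-of-stream; because the tungsten-sort shuffle merges spill files by concatenating serialized partition streams, such a sentinel can land in the middle of a merged stream and end reads early. The patch drops the sentinel and instead treats EOFException on the length read as end-of-stream. The following is a minimal, self-contained sketch of that framing scheme; all names are illustrative and not part of the patch.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException}

object LengthPrefixedStreams {
  // Each record is framed as a 4-byte big-endian length followed by the
  // record's bytes. No end-of-stream marker is written, so two streams can
  // be concatenated byte-for-byte and remain a valid stream.
  def writeRecords(out: DataOutputStream, records: Seq[Array[Byte]]): Unit = {
    records.foreach { r =>
      out.writeInt(r.length)
      out.write(r)
    }
    out.flush()
  }

  // Read until the underlying stream is exhausted: an EOFException on the
  // length read signals end-of-stream, echoing the patch's readSize().
  def readRecords(in: DataInputStream): Seq[Array[Byte]] = {
    val records = Seq.newBuilder[Array[Byte]]
    var size = try in.readInt() catch { case _: EOFException => -1 }
    while (size != -1) {
      val buf = new Array[Byte](size)
      in.readFully(buf)
      records += buf
      size = try in.readInt() catch { case _: EOFException => -1 }
    }
    records.result()
  }

  def main(args: Array[String]): Unit = {
    def stream(rows: String*): Array[Byte] = {
      val bytes = new ByteArrayOutputStream()
      writeRecords(new DataOutputStream(bytes), rows.map(_.getBytes("UTF-8")))
      bytes.toByteArray
    }
    // Concatenate two independently written streams, as the tungsten-sort
    // shuffle does when merging spill files. An in-band -1 sentinel at the
    // end of the first stream would stop the reader after "row2"; EOF-based
    // detection reads all three records.
    val merged = stream("row1", "row2") ++ stream("row3")
    readRecords(new DataInputStream(new ByteArrayInputStream(merged)))
      .foreach(r => println(new String(r, "UTF-8")))
  }
}

The key property is that the encoding is self-delimiting per record and carries no stream-level terminator, so byte-level concatenation of two valid streams is itself a valid stream.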
JoshRosen authored and marmbrus committed Sep 23, 2015
1 parent 27bfa9a commit a182080
Showing 2 changed files with 27 additions and 18 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -45,16 +45,9 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S
 }
 
 private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
-
-  /**
-   * Marks the end of a stream written with [[serializeStream()]].
-   */
-  private[this] val EOF: Int = -1
-
   /**
    * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
    * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
-   * The end of the stream is denoted by a record with the special length `EOF` (-1).
    */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
@@ -92,7 +85,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
 
     override def close(): Unit = {
       writeBuffer = null
-      dOut.writeInt(EOF)
       dOut.close()
     }
   }
@@ -104,12 +96,20 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
       private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
       private[this] var row: UnsafeRow = new UnsafeRow()
       private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
+      private[this] val EOF: Int = -1
 
       override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
         new Iterator[(Int, UnsafeRow)] {
-          private[this] var rowSize: Int = dIn.readInt()
-          if (rowSize == EOF) dIn.close()
+
+          private[this] def readSize(): Int = try {
+            dIn.readInt()
+          } catch {
+            case e: EOFException =>
+              dIn.close()
+              EOF
+          }
 
+          private[this] var rowSize: Int = readSize()
           override def hasNext: Boolean = rowSize != EOF
 
           override def next(): (Int, UnsafeRow) = {
@@ -118,7 +118,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
           }
           ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
           row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
-          rowSize = dIn.readInt() // read the next row's size
+          rowSize = readSize()
           if (rowSize == EOF) { // We are returning the last row in this stream
             dIn.close()
             val _rowTuple = rowTuple
sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream}
 
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.Utils
@@ -41,7 +42,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
   }
 }
 
-class UnsafeRowSerializerSuite extends SparkFunSuite {
+class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
     val converter = unsafeRowConverter(schema)
@@ -87,11 +88,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
   }
 
   test("close empty input stream") {
-    val baos = new ByteArrayOutputStream()
-    val dout = new DataOutputStream(baos)
-    dout.writeInt(-1) // EOF
-    dout.flush()
-    val input = new ClosableByteArrayInputStream(baos.toByteArray)
+    val input = new ClosableByteArrayInputStream(Array.empty)
     val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
     val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator
     assert(!deserializerIter.hasNext)
@@ -143,4 +140,16 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") {
+    val conf = new SparkConf()
+      .set("spark.shuffle.manager", "tungsten-sort")
+    sc = new SparkContext("local", "test", conf)
+    val row = Row("Hello", 123)
+    val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
+    val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)))
+      .asInstanceOf[RDD[Product2[Int, InternalRow]]]
+    val shuffled = new ShuffledRowRDD(rowsRDD, new UnsafeRowSerializer(2), 2)
+    shuffled.count()
+  }
 }
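
As a usage note, the regression test above selects the same shuffle implementation an application would enable via configuration. A minimal sketch under that assumption; the object, app name, master, and sample job are placeholders, not from the patch:

import org.apache.spark.{SparkConf, SparkContext}

object TungstenSortExample {                        // hypothetical example app
  def main(args: Array[String]): Unit = {
    // "tungsten-sort" is the Spark 1.5-era alias for UnsafeShuffleManager,
    // the shuffle path this commit makes work with UnsafeRowSerializer.
    val conf = new SparkConf()
      .setAppName("tungsten-sort-example")          // placeholder app name
      .setMaster("local[2]")                        // placeholder master
      .set("spark.shuffle.manager", "tungsten-sort")
    val sc = new SparkContext(conf)
    try {
      // Any shuffle here (e.g. groupByKey) now routes through tungsten-sort.
      println(sc.parallelize(1 to 100).map(i => (i % 4, i)).groupByKey().count())
    } finally {
      sc.stop()
    }
  }
}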
