Skip to content

Commit

Permalink
[SPARK-24257][SQL] LongToUnsafeRowMap calculate the new size may be w…
Browse files Browse the repository at this point in the history
…rong

LongToUnsafeRowMap has a mistake when growing its page array: it blindly grows to `oldSize * 2`, while the new record may be larger than `oldSize * 2`. Then we may have a malformed UnsafeRow when querying this map, whose actual data is smaller than its declared size, and the data is corrupted.

Author: sychen <[email protected]>

Closes apache#21311 from cxzl25/fix_LongToUnsafeRowMap_page_size.
  • Loading branch information
cxzl25 authored and cloud-fan committed May 24, 2018
1 parent 230f144 commit 8883401
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
def append(key: Long, row: UnsafeRow): Unit = {
val sizeInBytes = row.getSizeInBytes
if (sizeInBytes >= (1 << SIZE_BITS)) {
sys.error("Does not support row that is larger than 256M")
throw new UnsupportedOperationException("Does not support row that is larger than 256M")
}

if (key < minKey) {
Expand All @@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
maxKey = key
}

// There is 8 bytes for the pointer to next value
if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
val used = page.length
if (used >= (1 << 30)) {
sys.error("Can not build a HashedRelation that is larger than 8G")
}
ensureAcquireMemory(used * 8L * 2)
val newPage = new Array[Long](used * 2)
Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
cursor - Platform.LONG_ARRAY_OFFSET)
page = newPage
freeMemory(used * 8L)
}
grow(row.getSizeInBytes)

// copy the bytes of UnsafeRow
val offset = cursor
Expand Down Expand Up @@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
growArray()
} else if (numKeys > array.length / 2 * 0.75) {
// The fill ratio should be less than 0.75
sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
throw new UnsupportedOperationException(
"Cannot build HashedRelation with more than 1/3 billions unique keys")
}
}
} else {
Expand All @@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
}

private def grow(inputRowSize: Int): Unit = {
// There is 8 bytes for the pointer to next value
val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8
if (neededNumWords > page.length) {
if (neededNumWords > (1 << 30)) {
throw new UnsupportedOperationException(
"Can not build a HashedRelation that is larger than 8G")
}
val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30))
ensureAcquireMemory(newNumWords * 8L)
val newPage = new Array[Long](newNumWords.toInt)
Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
cursor - Platform.LONG_ARRAY_OFFSET)
val used = page.length
page = newPage
freeMemory(used * 8L)
}
}

private def growArray(): Unit = {
var old_array = array
val n = array.length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.CompactBuffer
Expand Down Expand Up @@ -254,6 +254,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
map.free()
}

test("SPARK-24257: insert big values into LongToUnsafeRowMap") {
val taskMemoryManager = new TaskMemoryManager(
new StaticMemoryManager(
new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
Long.MaxValue,
Long.MaxValue,
1),
0)
val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
val map = new LongToUnsafeRowMap(taskMemoryManager, 1)

val key = 0L
// the page array is initialized with length 1 << 17 (1M bytes),
// so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug
val bigStr = UTF8String.fromString("x" * (1 << 19))

map.append(key, unsafeProj(InternalRow(bigStr)))
map.optimize()

val resultRow = new UnsafeRow(1)
assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr)
map.free()
}

test("Spark-14521") {
val ser = new KryoSerializer(
(new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
Expand Down

0 comments on commit 8883401

Please sign in to comment.