Skip to content

Commit

Permalink
[SPARK-31511][SQL][2.4] Make BytesToBytesMap iterators thread-safe
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR increases the thread safety of the `BytesToBytesMap`:
- It makes the `iterator()` and `destructiveIterator()` methods used their own `Location` object. This used to be shared, and this was causing issues when the map was being iterated over in two threads by two different iterators.
- Removes the `safeIterator()` function. This is not needed anymore.
- Improves the documentation of a couple of methods w.r.t. thread-safety.

### Why are the changes needed?
It is unexpected an iterator shares the object it is returning with all other iterators. This is a violation of the iterator contract, and it causes issues with iterators over a map that are consumed in different threads.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
add ut

Closes apache#29605 from cxzl25/SPARK-31511.

Lead-authored-by: sychen <[email protected]>
Co-authored-by: herman <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
2 people authored and cloud-fan committed Sep 8, 2020
1 parent f2bcc93 commit 277ccba
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -421,11 +421,11 @@ private void handleFailedDelete() {
*
* For efficiency, all calls to `next()` will return the same {@link Location} object.
*
* If any other lookups or operations are performed on this map while iterating over it, including
* `lookup()`, the behavior of the returned iterator is undefined.
* The returned iterator is thread-safe. However if the map is modified while iterating over it,
* the behavior of the returned iterator is undefined.
*/
public MapIterator iterator() {
return new MapIterator(numValues, loc, false);
return new MapIterator(numValues, new Location(), false);
}

/**
Expand All @@ -435,19 +435,20 @@ public MapIterator iterator() {
*
* For efficiency, all calls to `next()` will return the same {@link Location} object.
*
* If any other lookups or operations are performed on this map while iterating over it, including
* `lookup()`, the behavior of the returned iterator is undefined.
* The returned iterator is thread-safe. However if the map is modified while iterating over it,
* the behavior of the returned iterator is undefined.
*/
public MapIterator destructiveIterator() {
updatePeakMemoryUsed();
return new MapIterator(numValues, loc, true);
return new MapIterator(numValues, new Location(), true);
}

/**
* Looks up a key, and return a {@link Location} handle that can be used to test existence
* and read/write values.
*
* This function always return the same {@link Location} instance to avoid object allocation.
* This function always returns the same {@link Location} instance to avoid object allocation.
* This function is not thread-safe.
*/
public Location lookup(Object keyBase, long keyOffset, int keyLength) {
safeLookup(keyBase, keyOffset, keyLength, loc,
Expand All @@ -459,7 +460,8 @@ public Location lookup(Object keyBase, long keyOffset, int keyLength) {
* Looks up a key, and return a {@link Location} handle that can be used to test existence
* and read/write values.
*
* This function always return the same {@link Location} instance to avoid object allocation.
* This function always returns the same {@link Location} instance to avoid object allocation.
* This function is not thread-safe.
*/
public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) {
safeLookup(keyBase, keyOffset, keyLength, loc, hash);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray))
}

test("SPARK-31511: Make BytesToBytesMap iterators thread-safe") {
val ser = sparkContext.env.serializer.newInstance()
val key = Seq(BoundReference(0, LongType, false))

val unsafeProj = UnsafeProjection.create(
Seq(BoundReference(0, LongType, false), BoundReference(1, IntegerType, true)))
val rows = (0 until 10000).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy())
val unsafeHashed = UnsafeHashedRelation(rows.iterator, key, 1, mm)

val os = new ByteArrayOutputStream()
val thread1 = new Thread {
override def run(): Unit = {
val out = new ObjectOutputStream(os)
unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
out.flush()
}
}

val thread2 = new Thread {
override def run(): Unit = {
val threadOut = new ObjectOutputStream(new ByteArrayOutputStream())
unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(threadOut)
threadOut.flush()
}
}

thread1.start()
thread2.start()
thread1.join()
thread2.join()

val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
val os2 = new ByteArrayOutputStream()
val out2 = new ObjectOutputStream(os2)
unsafeHashed2.writeExternal(out2)
out2.flush()
assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray))
}

// This test require 4G heap to run, should run it manually
ignore("build HashedRelation that is larger than 1G") {
val unsafeProj = UnsafeProjection.create(
Expand Down

0 comments on commit 277ccba

Please sign in to comment.