[SPARK-39084][PYSPARK] Fix df.rdd.isEmpty() by using TaskContext to stop iterator on task completion

### What changes were proposed in this pull request?

This PR fixes the issue described in https://issues.apache.org/jira/browse/SPARK-39084 where calling `df.rdd.isEmpty()` on a particular dataset could result in a JVM crash and/or executor failure.

The issue was caused by the Python iterator not being synchronised with the Java iterator, so when the task completes, the Python iterator continues to process data. We introduced `ContextAwareIterator` as part of https://issues.apache.org/jira/browse/SPARK-33277, but we did not apply it in all of the places where it should be used.
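
For context, `ContextAwareIterator` is a small wrapper around the task's row iterator that stops serving rows once the owning task has finished. A simplified sketch of the idea (not the verbatim class from Spark core):

```scala
import org.apache.spark.TaskContext

// Simplified sketch: stop handing out rows as soon as the task is completed or
// interrupted, so the Python-side consumer cannot keep reading after the Java
// task has released its resources.
class ContextAwareIterator[T](context: TaskContext, delegate: Iterator[T])
  extends Iterator[T] {

  override def hasNext: Boolean =
    !context.isCompleted() && !context.isInterrupted() && delegate.hasNext

  override def next(): T = delegate.next()
}
```

With the wrapper in place, `hasNext` returns `false` as soon as the task is marked completed or interrupted, so the Python side stops consuming rows backed by memory the finished task may already have freed.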

### Why are the changes needed?

Fixes the JVM crash when checking isEmpty() on a dataset.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

I added a test case that reproduces the issue 100%. I confirmed that the test fails without the fix and passes with the fix.

Closes apache#36425 from sadikovi/fix-pyspark-iter-2.

Authored-by: Ivan Sadikov <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
sadikovi authored and HyukjinKwon committed May 2, 2022
1 parent 6479455 commit 9305cc7
Showing 2 changed files with 38 additions and 1 deletion.
36 changes: 36 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -22,6 +22,7 @@
import tempfile
import time
import unittest
import uuid
from typing import cast

from pyspark.sql import SparkSession, Row
@@ -1176,6 +1177,41 @@ def test_df_show(self):
        with self.assertRaisesRegex(TypeError, "Parameter 'truncate=foo'"):
            df.show(truncate="foo")

    def test_df_is_empty(self):
        # SPARK-39084: Fix df.rdd.isEmpty() resulting in JVM crash.

        # This particular example of DataFrame reproduces an issue in isEmpty call
        # which could result in JVM crash.
        data = []
        for t in range(0, 10000):
            id = str(uuid.uuid4())
            if t == 0:
                for i in range(0, 99):
                    data.append((id,))
            elif t < 10:
                for i in range(0, 75):
                    data.append((id,))
            elif t < 100:
                for i in range(0, 50):
                    data.append((id,))
            elif t < 1000:
                for i in range(0, 25):
                    data.append((id,))
            else:
                for i in range(0, 10):
                    data.append((id,))

        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        try:
            df = self.spark.createDataFrame(data, ["col"])
            df.coalesce(1).write.parquet(tmpPath)

            res = self.spark.read.parquet(tmpPath).groupBy("col").count()
            self.assertFalse(res.rdd.isEmpty())
        finally:
            shutil.rmtree(tmpPath)

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        cast(str, pandas_requirement_message or pyarrow_requirement_message),
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._

import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}

import org.apache.spark.{ContextAwareIterator, TaskContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -301,7 +302,7 @@ object EvaluatePython {
  def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = {
    rdd.mapPartitions { iter =>
      registerPicklers() // let it called in executor
-     new SerDeUtil.AutoBatchedPickler(iter)
+     new SerDeUtil.AutoBatchedPickler(new ContextAwareIterator(TaskContext.get, iter))
    }
  }
}
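
To illustrate the effect of the one-line change above, here is a small, self-contained sketch (my own illustration, not part of the patch) in which a plain boolean flag stands in for `TaskContext.isCompleted()`: once the task is marked complete, for example because `isEmpty()` only needed the first row, the wrapper stops pulling from the underlying iterator instead of racing past the finished task.

```scala
// Stand-in for the real behaviour: a boolean flag plays the role of
// TaskContext.isCompleted(); the wrapper refuses to advance once it is set.
class StopOnCompletionIterator[T](delegate: Iterator[T], isCompleted: () => Boolean)
  extends Iterator[T] {
  override def hasNext: Boolean = !isCompleted() && delegate.hasNext
  override def next(): T = delegate.next()
}

object StopOnCompletionDemo {
  def main(args: Array[String]): Unit = {
    var done = false
    val rows = new StopOnCompletionIterator(Iterator(1, 2, 3, 4), () => done)
    println(rows.next())   // 1 -- task still running, row is served
    done = true            // task completes (e.g. only one row was needed)
    println(rows.hasNext)  // false -- no further rows are pulled from the delegate
  }
}
```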
