[SPARK-12002][STREAMING][PYSPARK] Fix python direct stream checkpoint recovery issue

Fixed a minor race condition in #10017

Closes #10017

Author: jerryshao <[email protected]>
Author: Shixiong Zhu <[email protected]>

Closes #10074 from zsxwing/review-pr10017.
jerryshao authored and zsxwing committed Dec 1, 2015
1 parent e76431f commit f292018
Showing 2 changed files with 56 additions and 6 deletions.
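
The core of the change is visible in the util.py diff below: TransformFunctionSerializer.dumps previously serialized only the user function and its deserializers, so a transform recovered from a checkpoint was rebuilt with the default RDD wrapper, and a direct Kafka stream would lose its KafkaRDD wrapping (and with it rdd.offsetRanges()). A minimal standalone sketch of that failure mode, using hypothetical stand-in functions rather than Spark's actual classes:

    import pickle

    # Hypothetical stand-ins: the default wrapper versus a Kafka-aware one
    # (the real one wraps the Java RDD into a KafkaRDD exposing offsetRanges()).
    def default_wrap(rdd):
        return ("plain", rdd)

    def kafka_wrap(rdd):
        return ("kafka", rdd)

    def user_transform(rdd):
        return rdd

    # Before the fix: only the user function survives the checkpoint round trip,
    # so recovery silently substitutes the default wrapper.
    recovered_func = pickle.loads(pickle.dumps(user_transform))
    recovered_wrap = default_wrap
    assert recovered_wrap("rdd")[0] == "plain"   # offsetRanges() would be gone

    # After the fix: the wrap function is serialized alongside the user function,
    # so recovery restores the Kafka-aware wrapping as well.
    recovered_func, recovered_wrap = pickle.loads(
        pickle.dumps((user_transform, kafka_wrap)))
    assert recovered_wrap("rdd")[0] == "kafka"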
49 changes: 49 additions & 0 deletions python/pyspark/streaming/tests.py
@@ -1149,6 +1149,55 @@ def test_topic_and_partition_equality(self):
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_c)
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)

+    @unittest.skipIf(sys.version >= "3", "long type not supported")
+    def test_kafka_direct_stream_transform_with_checkpoint(self):
+        """Test the Python direct Kafka stream transform with checkpoint correctly recovered."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        offsetRanges = []
+
+        def transformWithOffsetRanges(rdd):
+            for o in rdd.offsetRanges():
+                offsetRanges.append(o)
+            return rdd
+
+        self.ssc.stop(False)
+        self.ssc = None
+        tmpdir = "checkpoint-test-%d" % random.randint(0, 10000)
+
+        def setup():
+            ssc = StreamingContext(self.sc, 0.5)
+            ssc.checkpoint(tmpdir)
+            stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams)
+            stream.transform(transformWithOffsetRanges).count().pprint()
+            return ssc
+
+        try:
+            ssc1 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc1.start()
+            self.wait_for(offsetRanges, 1)
+            self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
+            # To make sure some checkpoint is written
+            time.sleep(3)
+            ssc1.stop(False)
+            ssc1 = None
+
+            # Restart again to make sure the checkpoint is recovered correctly
+            ssc2 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc2.start()
+            ssc2.awaitTermination(3)
+            ssc2.stop(stopSparkContext=False, stopGraceFully=True)
+            ssc2 = None
+        finally:
+            shutil.rmtree(tmpdir)
+
     @unittest.skipIf(sys.version >= "3", "long type not supported")
     def test_kafka_rdd_message_handler(self):
         """Test Python direct Kafka RDD MessageHandler."""
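The new test's recovery pattern hinges on StreamingContext.getOrCreate, which rebuilds a context from the checkpoint directory when one exists and calls the setup function only otherwise. A condensed usage sketch (hypothetical checkpoint path and batch interval; assumes an already-created SparkContext named sc):

    from pyspark.streaming import StreamingContext

    def create_context():
        # Runs only when no checkpoint exists yet.
        ssc = StreamingContext(sc, 1)              # sc: an existing SparkContext
        ssc.checkpoint("/tmp/example-checkpoint")  # hypothetical directory
        # ... build the DStream graph here ...
        return ssc

    # First run: builds the graph and starts writing checkpoints. After a
    # restart, the graph is restored from the checkpoint, including its
    # serialized transform functions (the part this commit fixes).
    ssc = StreamingContext.getOrCreate("/tmp/example-checkpoint", create_context)
    ssc.start()
    ssc.awaitTermination()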
13 changes: 7 additions & 6 deletions python/pyspark/streaming/util.py
@@ -37,11 +37,11 @@ def __init__(self, ctx, func, *deserializers):
         self.ctx = ctx
         self.func = func
         self.deserializers = deserializers
-        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+        self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
         self.failure = None

     def rdd_wrapper(self, func):
-        self._rdd_wrapper = func
+        self.rdd_wrap_func = func
         return self

     def call(self, milliseconds, jrdds):
@@ -59,7 +59,7 @@ def call(self, milliseconds, jrdds):
         if len(sers) < len(jrdds):
             sers += (sers[0],) * (len(jrdds) - len(sers))

-        rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
+        rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None
                 for jrdd, ser in zip(jrdds, sers)]
         t = datetime.fromtimestamp(milliseconds / 1000.0)
         r = self.func(t, *rdds)
@@ -101,16 +101,17 @@ def dumps(self, id):
         self.failure = None
         try:
             func = self.gateway.gateway_property.pool[id]
-            return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+            return bytearray(self.serializer.dumps((
+                func.func, func.rdd_wrap_func, func.deserializers)))
         except:
             self.failure = traceback.format_exc()

     def loads(self, data):
         # Clear the failure
         self.failure = None
         try:
-            f, deserializers = self.serializer.loads(bytes(data))
-            return TransformFunction(self.ctx, f, *deserializers)
+            f, wrap_func, deserializers = self.serializer.loads(bytes(data))
+            return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func)
         except:
             self.failure = traceback.format_exc()

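In short, dumps and loads now work as a pair: the wrap function is pickled next to the user function, and loads reattaches it through rdd_wrapper. A rough standalone mirror of that round trip (plain pickle and made-up class names; Spark routes this through its own serializer, so treat it as a sketch):

    import pickle

    def default_wrap(rdd):
        return rdd

    def to_kafka_rdd(rdd):  # module-level so plain pickle can serialize it
        return ("KafkaRDD", rdd)

    def identity(rdd):
        return rdd

    class MiniTransformFunction(object):
        """Illustrative stand-in for pyspark.streaming.util.TransformFunction."""
        def __init__(self, func):
            self.func = func
            self.rdd_wrap_func = default_wrap  # default: no special wrapping

        def rdd_wrapper(self, func):
            self.rdd_wrap_func = func
            return self

    def dumps(tf):
        # Mirrors the fixed dumps(): the wrap function rides along.
        return pickle.dumps((tf.func, tf.rdd_wrap_func))

    def loads(data):
        # Mirrors the fixed loads(): rebuild, then reattach the wrapper.
        f, wrap_func = pickle.loads(data)
        return MiniTransformFunction(f).rdd_wrapper(wrap_func)

    restored = loads(dumps(MiniTransformFunction(identity).rdd_wrapper(to_kafka_rdd)))
    assert restored.rdd_wrap_func("rdd")[0] == "KafkaRDD"  # wrapper survived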
