[SPARK-3554] [PySpark] use broadcast automatically for large closure
Py4j cannot handle large strings efficiently, so for large closures we should fall back to a broadcast variable automatically. (Broadcast uses the local filesystem to pass the data through, rather than the Py4J socket.)
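In outline, the change works like this hedged sketch (plain `pickle` and a hypothetical `broadcast_fn` standing in for `SparkContext.broadcast`; not PySpark's actual code): serialize the command, and once the payload exceeds 1 MiB, park the bytes in a broadcast variable and send only the small pickled handle across Py4J.

```python
import pickle

ONE_MB = 1 << 20  # the commit's threshold: 1,048,576 bytes

def prepare_command(command, broadcast_fn):
    # broadcast_fn stands in for SparkContext.broadcast and is assumed
    # to return a small, picklable handle to the stored bytes
    blob = pickle.dumps(command)
    if len(blob) > ONE_MB:
        handle = broadcast_fn(blob)   # large payload travels via broadcast
        blob = pickle.dumps(handle)   # only the tiny handle crosses Py4J
    return blob
```

The worker then checks whether what it unpickled is a broadcast handle and, if so, unpickles the handle's value to recover the real command (see the worker.py hunk below).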

Author: Davies Liu <[email protected]>

Closes apache#2417 from davies/command and squashes the following commits:

fbf4e97 [Davies Liu] bugfix
aefd508 [Davies Liu] use broadcast automatically for large closure
davies authored and JoshRosen committed Sep 19, 2014
1 parent 9306297 commit e77fa81
Showing 4 changed files with 19 additions and 3 deletions.
python/pyspark/rdd.py (4 additions, 0 deletions)
```diff
@@ -2061,8 +2061,12 @@ def _jrdd(self):
             self._jrdd_deserializer = NoOpSerializer()
         command = (self.func, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
```
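The threshold `1 << 20` is exactly 1 MiB. Below it nothing changes: the pickled command still travels inline through Py4J as before. Above it, the bytes are stored in a broadcast variable and only the pickled `Broadcast` handle is sent, so the worker can detect the wrapper with an `isinstance` check and fetch the real payload through the broadcast mechanism.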
python/pyspark/sql.py (6 additions, 2 deletions)
```diff
@@ -27,7 +27,7 @@
 from array import array
 from operator import itemgetter
 
-from pyspark.rdd import RDD, PipelinedRDD
+from pyspark.rdd import RDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -975,7 +975,11 @@ def registerFunction(self, name, f, returnType=StringType()):
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
-        pickled_command = CloudPickleSerializer().dumps(command)
+        ser = CloudPickleSerializer()
+        pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self._sc.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
             self._sc._gateway._gateway_client)
```
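This is the same threshold-and-wrap pattern applied to UDFs registered through `registerFunction`, reusing one `CloudPickleSerializer` instance so the broadcast handle is pickled the same way the command was; the now-unused `PipelinedRDD` import is dropped along the way.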
python/pyspark/tests.py (6 additions, 0 deletions)
```diff
@@ -434,6 +434,12 @@ def test_large_broadcast(self):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
 
+    def test_large_closure(self):
+        N = 1000000
+        data = [float(i) for i in xrange(N)]
+        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
+        self.assertEquals(N, m)
+
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
         b = self.sc.parallelize(range(100, 105))
```
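As a quick sanity check (plain CPython, not part of the test suite): a million floats pickle to several megabytes, so this test really does push the closure over the 1 MiB threshold and exercises the new broadcast path.

```python
import pickle

data = [float(i) for i in range(1000000)]
blob = pickle.dumps(data, protocol=2)
# binary pickling costs roughly 9 bytes per float, far beyond 1 MiB
print(len(blob), len(blob) > (1 << 20))
```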
python/pyspark/worker.py (3 additions, 1 deletion)
```diff
@@ -77,10 +77,12 @@ def main(infile, outfile):
             _broadcastRegistry[bid] = Broadcast(bid, value)
         else:
             bid = - bid - 1
-            _broadcastRegistry.remove(bid)
+            _broadcastRegistry.pop(bid)
 
     _accumulatorRegistry.clear()
     command = pickleSer._read_with_length(infile)
+    if isinstance(command, Broadcast):
+        command = pickleSer.loads(command.value)
     (func, deserializer, serializer) = command
     init_time = time.time()
     iterator = deserializer.load_stream(infile)
```
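Two things happen in this hunk. First, `_broadcastRegistry` is used as a dict a few lines up (`_broadcastRegistry[bid] = Broadcast(bid, value)`), and dicts have no `.remove` method, so switching the deletion to `.pop(bid)` (presumably the squashed `bugfix` commit) avoids an `AttributeError`. Second, the worker unwraps a broadcast-wrapped command. A minimal round-trip sketch of that unwrap, using plain `pickle` and a stand-in `Broadcast` class rather than PySpark's:

```python
import pickle

class Broadcast(object):
    # stand-in for pyspark.broadcast.Broadcast; only .value matters here
    def __init__(self, value):
        self.value = value

def read_command(blob):
    command = pickle.loads(blob)
    if isinstance(command, Broadcast):
        # large closure: the inline payload was just a handle; the real
        # pickled command sits behind the broadcast's .value
        command = pickle.loads(command.value)
    return command

command = ("func", "deserializer", "serializer")
small = pickle.dumps(command)                            # inline path
large = pickle.dumps(Broadcast(pickle.dumps(command)))   # broadcast path
assert read_command(small) == read_command(large) == command
```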
