Fix for sampling error in NumPy v1.9 [SPARK-3995][PYSPARK]
Changes the maximum value for the default seed during RDD sampling so that it is strictly less than 2 ** 32. This avoids a bug in the most recent version of NumPy (v1.9), whose random generator cannot accept seeds at or above this bound.

Adds an extra test that uses the default seed (instead of setting it manually, as in the docstrings).
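
For context, a minimal standalone sketch of the constraint (illustration only, not part of this commit): NumPy's RandomState accepts integer seeds only in the range [0, 2 ** 32 - 1], while the old default drew from [0, sys.maxint], which is 2 ** 63 - 1 on 64-bit Python 2 builds and can therefore exceed the accepted range.

    # Illustration only: NumPy rejects integer seeds at or above 2 ** 32.
    import numpy

    numpy.random.RandomState(2 ** 32 - 1)  # largest accepted seed, works
    try:
        numpy.random.RandomState(2 ** 32)  # one past the bound
    except ValueError as err:
        print("NumPy rejects the seed: %s" % err)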

mengxr

Author: freeman <[email protected]>

Closes apache#2889 from freeman-lab/pyspark-sampling and squashes the following commits:

dc385ef [freeman] Change maximum value for default seed
freeman-lab authored and mengxr committed Oct 22, 2014
1 parent f05e09b commit 97cf19f
Showing 2 changed files with 8 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/rddsampler.py
@@ -31,7 +31,7 @@ def __init__(self, withReplacement, seed=None):
                 "Falling back to default random generator for sampling.")
             self._use_numpy = False
 
-        self._seed = seed if seed is not None else random.randint(0, sys.maxint)
+        self._seed = seed if seed is not None else random.randint(0, 2 ** 32 - 1)
         self._withReplacement = withReplacement
         self._random = None
         self._split = None
@@ -47,7 +47,7 @@ def initRandomGenerator(self, split):
         for _ in range(0, split):
             # discard the next few values in the sequence to have a
             # different seed for the different splits
-            self._random.randint(0, sys.maxint)
+            self._random.randint(0, 2 ** 32 - 1)
 
         self._split = split
         self._rand_initialized = True
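
A side note on the surrounding initRandomGenerator logic: the loop above deliberately discards `split` draws so that each partition gets a distinct but reproducible stream from the same base seed. A hedged standalone sketch of that pattern (the function name stream_for_split is hypothetical, not from the codebase):

    # Hypothetical sketch of the per-split seeding pattern used above.
    import random

    def stream_for_split(seed, split):
        rng = random.Random(seed)
        for _ in range(split):
            rng.randint(0, 2 ** 32 - 1)  # discard one draw per earlier split
        return rng

    print(stream_for_split(42, 0).random())
    print(stream_for_split(42, 1).random())  # differs from split 0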
6 changes: 6 additions & 0 deletions python/pyspark/tests.py
@@ -433,6 +433,12 @@ def test_deleting_input_files(self):
         os.unlink(tempFile.name)
         self.assertRaises(Exception, lambda: filtered_data.count())
 
+    def test_sampling_default_seed(self):
+        # Test for SPARK-3995 (default seed setting)
+        data = self.sc.parallelize(range(1000), 1)
+        subset = data.takeSample(False, 10)
+        self.assertEqual(len(subset), 10)
+
     def testAggregateByKey(self):
         data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
 
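
The new test exercises the default-seed path end to end. For reference, the equivalent call in a PySpark shell (assuming a live SparkContext bound to `sc`, as the shell provides) would be:

    # Assumes a running PySpark shell where the SparkContext is `sc`.
    data = sc.parallelize(range(1000), 1)
    print(data.takeSample(False, 10))      # default seed path fixed here
    print(data.takeSample(False, 10, 42))  # explicit seed, as in the docstrings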
