[SPARK-39179][PYTHON][TESTS] Improve the test coverage for pyspark/shuffle.py

### What changes were proposed in this pull request?
This PR adds test cases for shuffle.py.

### Why are the changes needed?
To cover corner cases and increase test coverage. This brings the coverage of shuffle.py to close to 90%.
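For context, the claimed coverage can be checked locally with coverage.py; the snippet below is a minimal sketch (coverage.py and an importable PySpark `python/` directory are assumptions, not part of this PR):

```python
# Minimal sketch: measure line coverage of pyspark/shuffle.py while
# running this test module. Assumes `pip install coverage` and that
# PySpark's python/ directory is on PYTHONPATH.
import unittest

import coverage

cov = coverage.Coverage(include=["*/pyspark/shuffle.py"])
cov.start()
unittest.main(module="pyspark.tests.test_shuffle", exit=False)
cov.stop()
cov.report()
```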

### Does this PR introduce _any_ user-facing change?
No, test only.

### How was this patch tested?
CI in this PR should test it out
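For anyone reproducing outside CI, a minimal sketch of running just this module with the standard unittest runner (an importable PySpark dev environment is assumed):

```python
# Minimal sketch: run pyspark/tests/test_shuffle.py locally without the
# full CI harness. Assumes PySpark's python/ directory is importable.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("pyspark.tests.test_shuffle")
unittest.TextTestRunner(verbosity=2).run(suite)
```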

Closes apache#36701 from pralabhkumar/rk_test_taskcontext.

Authored-by: pralabhkumar <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
pralabhkumar authored and dongjoon-hyun committed Jun 5, 2022
1 parent 6277fc7 commit bab70b1
Showing 1 changed file with 93 additions and 1 deletion.
python/pyspark/tests/test_shuffle.py
@@ -16,11 +16,20 @@
#
import random
import unittest
import tempfile
import os

from py4j.protocol import Py4JJavaError

from pyspark import shuffle, CPickleSerializer, SparkConf, SparkContext
from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
from pyspark.shuffle import (
    Aggregator,
    ExternalMerger,
    ExternalSorter,
    SimpleAggregator,
    Merger,
    ExternalGroupBy,
)


class MergerTests(unittest.TestCase):
@@ -54,6 +63,57 @@ def test_medium_dataset(self):
        self.assertTrue(m.spills >= 1)
        self.assertEqual(sum(sum(v) for k, v in m.items()), sum(range(self.N)) * 3)

    def test_shuffle_data_with_multiple_locations(self):
        # SPARK-39179: Test shuffle of data with multiple locations; also check
        # that shuffle locations get randomized.

        with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2:
            original = os.environ.get("SPARK_LOCAL_DIRS", None)
            os.environ["SPARK_LOCAL_DIRS"] = tempdir1 + "," + tempdir2
            try:
                index_of_tempdir1 = [False, False]
                for idx in range(10):
                    m = ExternalMerger(self.agg, 20)
                    if m.localdirs[0].startswith(tempdir1):
                        index_of_tempdir1[0] = True
                    elif m.localdirs[1].startswith(tempdir1):
                        index_of_tempdir1[1] = True
                    m.mergeValues(self.data)
                    self.assertTrue(m.spills >= 1)
                    self.assertEqual(sum(sum(v) for k, v in m.items()), sum(range(self.N)))
                self.assertTrue(
                    index_of_tempdir1[0] and (index_of_tempdir1[0] == index_of_tempdir1[1])
                )
            finally:
                if original is not None:
                    os.environ["SPARK_LOCAL_DIRS"] = original
                else:
                    del os.environ["SPARK_LOCAL_DIRS"]

    def test_simple_aggregator_with_medium_dataset(self):
        # SPARK-39179: Test SimpleAggregator
        agg = SimpleAggregator(lambda x, y: x + y)
        m = ExternalMerger(agg, 20)
        m.mergeValues(self.data)
        self.assertTrue(m.spills >= 1)
        self.assertEqual(sum(v for k, v in m.items()), sum(range(self.N)))

    def test_merger_not_implemented_error(self):
        # SPARK-39179: Test Merger for error scenarios
        agg = SimpleAggregator(lambda x, y: x + y)

        class DummyMerger(Merger):
            def __init__(self, agg):
                Merger.__init__(self, agg)

        dummy_merger = DummyMerger(agg)
        with self.assertRaises(NotImplementedError):
            dummy_merger.mergeValues(self.data)
        with self.assertRaises(NotImplementedError):
            dummy_merger.mergeCombiners(self.data)
        with self.assertRaises(NotImplementedError):
            dummy_merger.items()

    def test_huge_dataset(self):
        m = ExternalMerger(self.agg, 5, partitions=3)
        m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
@@ -117,6 +177,38 @@ def legit_merge_combiners(x, y):
        m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))


class ExternalGroupByTests(unittest.TestCase):
    def setUp(self):
        self.N = 1 << 20
        values = [i for i in range(self.N)]
        keys = [i for i in range(2)]
        import itertools

        self.data = [value for value in itertools.product(keys, values)]
        self.agg = Aggregator(
            lambda x: [x], lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x
        )

    def test_medium_dataset(self):
        # SPARK-39179: Test external group-by for a medium dataset
        m = ExternalGroupBy(self.agg, 5, partitions=3)
        m.mergeValues(self.data)
        self.assertTrue(m.spills >= 1)
        self.assertEqual(sum(sum(v) for k, v in m.items()), 2 * sum(range(self.N)))

    def test_dataset_with_keys_are_unsorted(self):
        # SPARK-39179: Test external group-by when the number of keys is greater
        # than SORT_KEY_LIMIT.
        m = ExternalGroupBy(self.agg, 5, partitions=3)
        original = m.SORT_KEY_LIMIT
        try:
            m.SORT_KEY_LIMIT = 1
            m.mergeValues(self.data)
            self.assertTrue(m.spills >= 1)
            self.assertEqual(sum(sum(v) for k, v in m.items()), 2 * sum(range(self.N)))
        finally:
            m.SORT_KEY_LIMIT = original


class SorterTests(unittest.TestCase):
    def test_in_memory_sort(self):
        lst = list(range(1024))
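For readers skimming the diff, the following is a minimal sketch of the ExternalMerger contract these tests exercise; the 10 MB memory limit and the toy key/value data are illustrative assumptions, not values from the PR:

```python
# Minimal sketch: an Aggregator bundles createCombiner / mergeValue /
# mergeCombiners, and ExternalMerger applies it under a memory cap
# (in MB), spilling partitions to local disk when the cap is exceeded.
from pyspark.shuffle import Aggregator, ExternalMerger

agg = Aggregator(
    createCombiner=lambda v: [v],
    mergeValue=lambda c, v: c.append(v) or c,
    mergeCombiners=lambda c1, c2: c1.extend(c2) or c1,
)

merger = ExternalMerger(agg, memory_limit=10)
merger.mergeValues((i % 4, i) for i in range(1_000_000))

# spills counts how many times partitions were flushed to disk;
# items() streams the merged (key, combiner) pairs back.
print(merger.spills, sorted((k, len(v)) for k, v in merger.items()))
```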
