[Datasets] Minimize truncation on balanced splits. (ray-project#18953)
* Minimize truncation on balanced splits.

* Refactor into subroutines.

* Feedback and fixes.
clarkzinzow authored Sep 30, 2021
1 parent 5709c65 commit 74b5d3d
Showing 2 changed files with 225 additions and 13 deletions.
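
As a quick illustration of the new behavior (a minimal usage sketch, assuming a Ray installation that includes this change): previously, equal splits were produced by truncating every split down to the size of the smallest one, which could drop many rows when block sizes were skewed; with this change each split receives total_rows // n rows, so at most total_rows % n rows are dropped overall.

import ray

ray.init()
ds = ray.data.range(13)
splits = ds.split(3, equal=True)
# Per the new tests, every split holds 13 // 3 == 4 rows, so only one row is dropped.
print([s.count() for s in splits])  # [4, 4, 4]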
153 changes: 140 additions & 13 deletions python/ray/data/dataset.py
@@ -402,24 +402,150 @@ def split(self,
if n <= 0:
raise ValueError(f"The number of splits {n} is not positive.")

if n > self.num_blocks() and equal:
raise NotImplementedError(
f"The number of splits {n} > the number of dataset blocks "
f"{self.num_blocks()}, yet an equal split was requested.")

if locality_hints and len(locality_hints) != n:
raise ValueError(
f"The length of locality_hints {len(locality_hints)} "
"doesn't equal the number of splits {n}.")

# TODO(ekl) we could do better than truncation here. This could be a
# problem if block sizes are very skewed.
def equalize(splits: List[Dataset[T]]) -> List[Dataset[T]]:
def _partition_splits(splits: List[Dataset[T]], part_size: int,
counts_cache: Dict[str, int]):
"""Partition splits into two sets: splits that are smaller than the
target size and splits that are larger than the target size.
"""
splits = sorted(splits, key=lambda s: counts_cache[s._get_uuid()])
idx = next(i for i, split in enumerate(splits)
if counts_cache[split._get_uuid()] >= part_size)
return splits[:idx], splits[idx:]

def _equalize_larger_splits(splits: List[Dataset[T]], target_size: int,
counts_cache: Dict[str, int],
num_splits_required: int):
"""Split each split into one or more subsplits that are each the
target size, with at most one leftover split that's smaller
than the target size.
This assumes that the given splits are sorted in ascending order.
"""
new_splits = []
leftovers = []
for split in splits:
size = counts_cache[split._get_uuid()]
if size == target_size:
new_splits.append(split)
continue
split_indices = list(range(target_size, size, target_size))
split_splits = split.split_at_indices(split_indices)
last_split_size = split_splits[-1].count()
if last_split_size < target_size:
# Last split is smaller than the target size, save it for
# our unioning of small splits.
leftover = split_splits.pop()
leftovers.append(leftover)
counts_cache[leftover._get_uuid()] = leftover.count()
if len(new_splits) + len(split_splits) >= num_splits_required:
# Short-circuit if the new splits will make us reach the
# desired number of splits.
new_splits.extend(
split_splits[:num_splits_required - len(new_splits)])
break
new_splits.extend(split_splits)
return new_splits, leftovers

def _equalize_smaller_splits(
splits: List[Dataset[T]], target_size: int,
counts_cache: Dict[str, int], num_splits_required: int):
"""Union small splits up to the target split size.
This assumes that the given splits are sorted in ascending order.
"""
new_splits = []
union_buffer = []
union_buffer_size = 0
low = 0
high = len(splits) - 1
while low <= high:
# Union small splits up to the target split size.
low_split = splits[low]
low_count = counts_cache[low_split._get_uuid()]
high_split = splits[high]
high_count = counts_cache[high_split._get_uuid()]
if union_buffer_size + high_count <= target_size:
# Try to add the larger split to the union buffer first.
union_buffer.append(high_split)
union_buffer_size += high_count
high -= 1
elif union_buffer_size + low_count <= target_size:
union_buffer.append(low_split)
union_buffer_size += low_count
low += 1
else:
# Neither the larger nor the smaller split fits in the union
# buffer, so we split the smaller split into a subsplit
# that will fit into the union buffer and a leftover
# subsplit that we add back into the candidate split list.
diff = target_size - union_buffer_size
diff_split, new_low_split = low_split.split_at_indices(
[diff])
union_buffer.append(diff_split)
union_buffer_size += diff
# We overwrite the old low split and don't advance the low
# pointer since (1) the old low split can be discarded,
# (2) the leftover subsplit is guaranteed to be smaller
# than the old low split, and (3) the low split should be
# the smallest split in the candidate split list, which is
# this subsplit.
splits[low] = new_low_split
counts_cache[new_low_split._get_uuid()] = low_count - diff
if union_buffer_size == target_size:
# Once the union buffer is full, we union together the
# splits.
assert len(union_buffer) > 1, union_buffer
first_ds = union_buffer[0]
new_split = first_ds.union(*union_buffer[1:])
new_splits.append(new_split)
# Clear the union buffer.
union_buffer = []
union_buffer_size = 0
if len(new_splits) == num_splits_required:
# Short-circuit if we've reached the desired number of
# splits.
break
return new_splits

def equalize(splits: List[Dataset[T]],
num_splits: int) -> List[Dataset[T]]:
if not equal:
return splits
lower_bound = min([s.count() for s in splits])
assert lower_bound > 0, splits
return [s.limit(lower_bound) for s in splits]
counts = {s._get_uuid(): s.count() for s in splits}
total_rows = sum(counts.values())
# Number of rows for each split.
target_size = total_rows // num_splits

# Partition splits.
smaller_splits, larger_splits = _partition_splits(
splits, target_size, counts)
if len(smaller_splits) == 0 and num_splits < len(splits):
# All splits are already equal.
return splits

# Split larger splits.
new_splits, leftovers = _equalize_larger_splits(
larger_splits, target_size, counts, num_splits)
# Short-circuit if we've already reached the desired number of
# splits.
if len(new_splits) == num_splits:
return new_splits
# Add leftovers to small splits and re-sort.
smaller_splits += leftovers
smaller_splits = sorted(
smaller_splits, key=lambda s: counts[s._get_uuid()])

# Union smaller splits.
new_splits_small = _equalize_smaller_splits(
smaller_splits, target_size, counts,
num_splits - len(new_splits))
new_splits.extend(new_splits_small)
return new_splits

block_refs = list(self._blocks)
metadata_mapping = {
@@ -433,7 +559,8 @@ def equalize(splits: List[Dataset[T]]) -> List[Dataset[T]]:
BlockList(
list(blocks), [metadata_mapping[b] for b in blocks]))
for blocks in np.array_split(block_refs, n)
])
if not equal or len(blocks) > 0
], n)

# If the locality_hints is set, we use a two-round greedy algorithm
# to co-locate the blocks with the actors based on block
@@ -532,7 +659,7 @@ def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]:
[metadata_mapping[b]
for b in allocation_per_actor[actor]]))
for actor in locality_hints
])
], n)

def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]:
"""Split the dataset at the given indices (like np.split).
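
To make the carve-and-union strategy above easier to follow, here is a simplified, self-contained sketch of the same idea operating on plain Python lists of rows rather than Dataset objects. The function equalize_row_lists and its structure are illustrative only and are not part of the Ray API; they mirror the three phases in the diff above: partition splits by the target size, carve larger splits into target-sized pieces, then union the smaller pieces up to the target size.

from typing import List


def equalize_row_lists(splits: List[List[int]],
                       num_splits: int) -> List[List[int]]:
    """Rebalance row lists into num_splits lists of equal length, dropping
    at most total_rows % num_splits rows overall (illustrative sketch)."""
    total_rows = sum(len(s) for s in splits)
    assert total_rows >= num_splits
    target = total_rows // num_splits

    # Phase 1: partition into splits below the target and splits at/above it.
    splits = sorted(splits, key=len)
    smaller = [s for s in splits if len(s) < target]
    larger = [s for s in splits if len(s) >= target]

    # Phase 2: carve each larger split into target-sized pieces; remainders
    # become leftovers that are recombined with the smaller splits.
    new_splits, leftovers = [], []
    for s in larger:
        for i in range(0, len(s), target):
            piece = s[i:i + target]
            (new_splits if len(piece) == target else leftovers).append(piece)

    # Phase 3: union the smaller splits and leftovers up to the target size,
    # splitting a piece whenever it would overflow the buffer.
    pool = sorted(smaller + leftovers, key=len)
    buffer: List[int] = []
    while pool and len(new_splits) < num_splits:
        piece = pool.pop()  # largest remaining piece first
        room = target - len(buffer)
        buffer.extend(piece[:room])
        if len(piece) > room:
            pool.append(piece[room:])
            pool.sort(key=len)
        if len(buffer) == target:
            new_splits.append(buffer)
            buffer = []
    return new_splits[:num_splits]


if __name__ == "__main__":
    blocks = [list(range(3)), list(range(3, 9)), list(range(9, 13))]  # sizes 3, 6, 4
    out = equalize_row_lists(blocks, 3)
    assert [len(s) for s in out] == [4, 4, 4]

On block sizes [3, 6, 4] with 3 splits this yields three splits of 4 rows each, dropping only 13 % 3 == 1 row, whereas truncating every split to the smallest one would keep only 9 rows.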
85 changes: 85 additions & 0 deletions python/ray/data/tests/test_dataset.py
@@ -17,9 +17,11 @@
import ray

from ray.tests.conftest import * # noqa
from ray.data.dataset import Dataset
from ray.data.datasource import DummyOutputDatasource
from ray.data.datasource.csv_datasource import CSVDatasource
from ray.data.block import BlockAccessor
from ray.data.impl.block_list import BlockList
from ray.data.datasource.file_based_datasource import _unwrap_protocol
from ray.data.extensions.tensor_extension import (
TensorArray, TensorDtype, ArrowTensorType, ArrowTensorArray)
@@ -142,6 +144,89 @@ def __call__(self, x):
assert len(actor_reuse) == 10, actor_reuse


@pytest.mark.parametrize(
"block_sizes,num_splits",
[
( # Test baseline.
[3, 6, 3], 3),
( # Already balanced.
[3, 3, 3], 3),
( # Row truncation.
[3, 6, 4], 3),
( # Row truncation, smaller number of blocks.
[3, 6, 2, 3], 3),
( # Row truncation, larger number of blocks.
[5, 6, 2, 5], 5),
( # All smaller but one.
[1, 1, 1, 1, 6], 5),
( # All larger but one.
[4, 4, 4, 4, 1], 5),
( # Single block.
[2], 2),
( # Single split.
[2, 5], 1),
])
def test_equal_split_balanced(ray_start_regular_shared, block_sizes,
num_splits):
_test_equal_split_balanced(block_sizes, num_splits)


def _test_equal_split_balanced(block_sizes, num_splits):
blocks = []
metadata = []
total_rows = 0
for block_size in block_sizes:
block = list(range(total_rows, total_rows + block_size))
blocks.append(ray.put(block))
metadata.append(BlockAccessor.for_block(block).get_metadata(None))
total_rows += block_size
block_list = BlockList(blocks, metadata)
ds = Dataset(block_list)

splits = ds.split(num_splits, equal=True)
split_counts = [split.count() for split in splits]
assert len(split_counts) == num_splits
expected_block_size = total_rows // num_splits
# Check that all splits are the expected size.
assert all([count == expected_block_size for count in split_counts])
expected_total_rows = sum(split_counts)
# Check that the expected number of rows were dropped.
assert total_rows - expected_total_rows == total_rows % num_splits
# Check that all rows are unique (content check).
split_rows = [row for split in splits for row in split.take(total_rows)]
assert len(set(split_rows)) == len(split_rows)


def test_equal_split_balanced_grid(ray_start_regular_shared):

# Tests balanced equal splitting over a grid of configurations.
# Grid: num_blocks x num_splits x num_rows_block_1 x ... x num_rows_block_n
seed = int(time.time())
print(f"Seeding RNG for test_equal_split_balanced_grid with: {seed}")
random.seed(seed)
max_num_splits = 20
num_splits_samples = 5
max_num_blocks = 50
max_num_rows_per_block = 100
num_blocks_samples = 5
block_sizes_samples = 5
for num_splits in np.random.randint(
2, max_num_splits + 1, size=num_splits_samples):
for num_blocks in np.random.randint(
1, max_num_blocks + 1, size=num_blocks_samples):
block_sizes_list = [
np.random.randint(
1, max_num_rows_per_block + 1, size=num_blocks)
for _ in range(block_sizes_samples)
]
for block_sizes in block_sizes_list:
if sum(block_sizes) < num_splits:
min_ = math.ceil(num_splits / num_blocks)
block_sizes = np.random.randint(
min_, max_num_rows_per_block + 1, size=num_blocks)
_test_equal_split_balanced(block_sizes, num_splits)


@pytest.mark.parametrize("pipelined", [False, True])
def test_basic(ray_start_regular_shared, pipelined):
ds0 = ray.data.range(5)
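To exercise the new tests locally, the parameterized cases above can be selected by name with pytest (a usage note, assuming a Ray development checkout with the test dependencies installed):

pytest -v python/ray/data/tests/test_dataset.py -k equal_split_balanced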
