From c93e8aecc981a165be24ec20ad08a4190892e983 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 16 Mar 2023 14:06:00 -0700 Subject: [PATCH] [data] [streaming] Support iterating over streaming_split() iterators multiple times (#33303) In order to support Train integration, the iterators returned from Dataset.streaming_split() must be repeatable, as with the other iterators. --- doc/source/ray-observability/ray-logging.rst | 2 + .../data/_internal/dataset_iterator_impl.py | 8 +- .../data/_internal/execution/interfaces.py | 3 + python/ray/data/_internal/progress_bar.py | 2 +- .../stream_split_dataset_iterator.py | 97 ++++++++++++--- .../data/tests/test_streaming_integration.py | 110 ++++++++++++++---- python/ray/experimental/tqdm_ray.py | 13 ++- python/ray/tune/progress_reporter.py | 5 +- 8 files changed, 200 insertions(+), 40 deletions(-) diff --git a/doc/source/ray-observability/ray-logging.rst b/doc/source/ray-observability/ray-logging.rst index d95407d344ae..d2c51ff5fd6e 100644 --- a/doc/source/ray-observability/ray-logging.rst +++ b/doc/source/ray-observability/ray-logging.rst @@ -95,6 +95,8 @@ Limitations: - Only a subset of tqdm functionality is supported. Refer to the ray_tqdm `implementation `__ for more details. - Performance may be poor if there are more than a couple thousand updates per second (updates are not batched). +Tip: To avoid `print` statements from the driver conflicting with tqdm output, use `ray.experimental.tqdm_ray.safe_print` instead. + How to set up loggers ~~~~~~~~~~~~~~~~~~~~~ When using ray, all of the tasks and actors are executed remotely in Ray's worker processes. diff --git a/python/ray/data/_internal/dataset_iterator_impl.py b/python/ray/data/_internal/dataset_iterator_impl.py index 17a1577e9edc..ee94a0c90e26 100644 --- a/python/ray/data/_internal/dataset_iterator_impl.py +++ b/python/ray/data/_internal/dataset_iterator_impl.py @@ -5,6 +5,7 @@ from ray.data._internal.util import _default_batch_format from ray.data.block import DataBatch +from ray.data.context import DatasetContext from ray.data.dataset_iterator import DatasetIterator from ray.data._internal.block_batching import batch_block_refs @@ -24,6 +25,7 @@ def __init__( base_dataset: "Dataset", ): self._base_dataset = base_dataset + self._base_context = DatasetContext.get_current() def __repr__(self) -> str: return f"DatasetIterator({self._base_dataset})" @@ -40,6 +42,8 @@ def iter_batches( _collate_fn: Optional[Callable[[DataBatch], Any]] = None, ) -> Iterator[DataBatch]: + DatasetContext._set_current(self._base_context) + ds = self._base_dataset block_iterator, stats, executor = ds._plan.execute_to_iterator() ds._current_executor = executor @@ -85,5 +89,5 @@ def __getattr__(self, name): ) return getattr(self._base_dataset, name) - else: - return super().__getattr__(name) + + raise AttributeError() diff --git a/python/ray/data/_internal/execution/interfaces.py b/python/ray/data/_internal/execution/interfaces.py index 6b52c90df210..db227afe4438 100644 --- a/python/ray/data/_internal/execution/interfaces.py +++ b/python/ray/data/_internal/execution/interfaces.py @@ -448,6 +448,9 @@ def get_next(self, output_split_idx: Optional[int] = None) -> RefBundle: Args: output_split_idx: The output split index to get results for. This arg is only allowed for iterators created by `Dataset.streaming_split()`. + + Raises: + StopIteration if there are no more outputs to return. 
""" if output_split_idx is not None: raise NotImplementedError() diff --git a/python/ray/data/_internal/progress_bar.py b/python/ray/data/_internal/progress_bar.py index 78ea30cd0bd7..edcc132aa44e 100644 --- a/python/ray/data/_internal/progress_bar.py +++ b/python/ray/data/_internal/progress_bar.py @@ -46,7 +46,7 @@ class ProgressBar: def __init__(self, name: str, total: int, position: int = 0): self._desc = name - if not _enabled or threading.current_thread() is not threading.main_thread(): + if not _enabled: self._bar = None elif tqdm: ctx = ray.data.context.DatasetContext.get_current() diff --git a/python/ray/data/_internal/stream_split_dataset_iterator.py b/python/ray/data/_internal/stream_split_dataset_iterator.py index 661d2bf1c3a8..f4a37c9eb0b2 100644 --- a/python/ray/data/_internal/stream_split_dataset_iterator.py +++ b/python/ray/data/_internal/stream_split_dataset_iterator.py @@ -1,6 +1,7 @@ import copy import logging import sys +import time import threading from typing import ( List, @@ -27,6 +28,7 @@ from ray.data._internal.execution.operators.output_splitter import OutputSplitter from ray.data._internal.execution.interfaces import NodeIdStr, RefBundle from ray.types import ObjectRef +from ray.util.debug import log_once from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy if TYPE_CHECKING: @@ -41,6 +43,9 @@ logger = logging.getLogger(__name__) +BLOCKED_CLIENT_WARN_TIMEOUT = 30 + + class StreamSplitDatasetIterator(DatasetIterator): """Implements a collection of iterators over a shared data stream.""" @@ -93,15 +98,20 @@ def iter_batches( """Implements DatasetIterator.""" def gen_blocks() -> Iterator[ObjectRef[Block]]: + cur_epoch = ray.get( + self._coord_actor.start_epoch.remote(self._output_split_idx) + ) future: ObjectRef[ Optional[ObjectRef[Block]] - ] = self._coord_actor.get.remote(self._output_split_idx) + ] = self._coord_actor.get.remote(cur_epoch, self._output_split_idx) while True: block_ref: Optional[ObjectRef[Block]] = ray.get(future) if not block_ref: break else: - future = self._coord_actor.get.remote(self._output_split_idx) + future = self._coord_actor.get.remote( + cur_epoch, self._output_split_idx + ) yield block_ref yield from batch_block_refs( @@ -154,29 +164,54 @@ def __init__( self._n = n self._equal = equal self._locality_hints = locality_hints - self._finished = False self._lock = threading.RLock() + # Guarded by self._lock. self._next_bundle: Dict[int, RefBundle] = {} + self._unfinished_clients_in_epoch = n + self._cur_epoch = -1 - executor = StreamingExecutor(copy.deepcopy(ctx.execution_options)) + def gen_epochs(): + while True: + executor = StreamingExecutor(copy.deepcopy(ctx.execution_options)) - def add_split_op(dag): - return OutputSplitter(dag, n, equal, locality_hints) + def add_split_op(dag): + return OutputSplitter(dag, n, equal, locality_hints) - self._output_iterator = execute_to_legacy_bundle_iterator( - executor, - dataset._plan, - True, - dataset._plan._dataset_uuid, - dag_rewrite=add_split_op, - ) + output_iterator = execute_to_legacy_bundle_iterator( + executor, + dataset._plan, + True, + dataset._plan._dataset_uuid, + dag_rewrite=add_split_op, + ) + yield output_iterator + + self._next_epoch = gen_epochs() + self._output_iterator = None + + def start_epoch(self, split_idx: int) -> str: + """Called to start an epoch. - def get(self, output_split_idx: int) -> Optional[ObjectRef[Block]]: + Returns: + UUID for the epoch, which must be used when accessing results via get(). 
+ """ + + # Wait for all clients to arrive at the barrier before starting a new epoch. + epoch_id = self._barrier(split_idx) + return epoch_id + + def get(self, epoch_id: int, output_split_idx: int) -> Optional[ObjectRef[Block]]: """Blocking get operation. This is intended to be called concurrently from multiple clients. """ + + if epoch_id != self._cur_epoch: + raise ValueError( + "Invalid iterator: the datastream has moved on to another epoch." + ) + try: # Ensure there is at least one bundle. with self._lock: @@ -201,3 +236,37 @@ def get(self, output_split_idx: int) -> Optional[ObjectRef[Block]]: return block except StopIteration: return None + + def _barrier(self, split_idx: int) -> int: + """Arrive and block until the start of the given epoch.""" + + # Decrement and await all clients to arrive here. + with self._lock: + starting_epoch = self._cur_epoch + self._unfinished_clients_in_epoch -= 1 + + start_time = time.time() + while ( + self._cur_epoch == starting_epoch and self._unfinished_clients_in_epoch != 0 + ): + if time.time() - start_time > BLOCKED_CLIENT_WARN_TIMEOUT: + if log_once(f"stream_split_blocked_{split_idx}_{starting_epoch}"): + logger.warning( + f"StreamSplitDatasetIterator(epoch={starting_epoch}, " + f"split={split_idx}) blocked waiting on other clients " + f"for more than {BLOCKED_CLIENT_WARN_TIMEOUT}s. All " + "clients must read from the DatasetIterator splits at " + "the same time. This warning will not be printed again " + "for this epoch." + ) + time.sleep(0.1) + + # Advance to the next epoch. + with self._lock: + if self._cur_epoch == starting_epoch: + self._cur_epoch += 1 + self._unfinished_clients_in_epoch = self._n + self._output_iterator = next(self._next_epoch) + + assert self._output_iterator is not None + return starting_epoch + 1 diff --git a/python/ray/data/tests/test_streaming_integration.py b/python/ray/data/tests/test_streaming_integration.py index ffdfa4d9642a..00e93c131d1a 100644 --- a/python/ray/data/tests/test_streaming_integration.py +++ b/python/ray/data/tests/test_streaming_integration.py @@ -92,15 +92,29 @@ def run(self): def test_streaming_split_e2e(ray_start_10_cpus_shared): def get_lengths(*iterators, use_iter_batches=True): lengths = [] - for it in iterators: - x = 0 - if use_iter_batches: - for batch in it.iter_batches(): - x += len(batch) - else: - for _ in it.iter_rows(): - x += 1 - lengths.append(x) + + class Runner(threading.Thread): + def __init__(self, it): + self.it = it + super().__init__() + + def run(self): + it = self.it + x = 0 + if use_iter_batches: + for batch in it.iter_batches(): + x += len(batch) + else: + for _ in it.iter_rows(): + x += 1 + lengths.append(x) + + runners = [Runner(it) for it in iterators] + for r in runners: + r.start() + for r in runners: + r.join() + lengths.sort() return lengths @@ -109,35 +123,91 @@ def get_lengths(*iterators, use_iter_batches=True): i1, i2, ) = ds.streaming_split(2, equal=True) - lengths = get_lengths(i1, i2) - assert lengths == [500, 500], lengths + for _ in range(2): + lengths = get_lengths(i1, i2) + assert lengths == [500, 500], lengths ds = ray.data.range(1) ( i1, i2, ) = ds.streaming_split(2, equal=True) - lengths = get_lengths(i1, i2) - assert lengths == [0, 0], lengths + for _ in range(2): + lengths = get_lengths(i1, i2) + assert lengths == [0, 0], lengths ds = ray.data.range(1) ( i1, i2, ) = ds.streaming_split(2, equal=False) - lengths = get_lengths(i1, i2) - assert lengths == [0, 1], lengths + for _ in range(2): + lengths = get_lengths(i1, i2) + assert lengths == [0, 1], 
lengths ds = ray.data.range(1000, parallelism=10) for equal_split, use_iter_batches in itertools.product( [True, False], [True, False] ): i1, i2, i3 = ds.streaming_split(3, equal=equal_split) - lengths = get_lengths(i1, i2, i3, use_iter_batches=use_iter_batches) - if equal_split: - assert lengths == [333, 333, 333], lengths - else: - assert lengths == [300, 300, 400], lengths + for _ in range(2): + lengths = get_lengths(i1, i2, i3, use_iter_batches=use_iter_batches) + if equal_split: + assert lengths == [333, 333, 333], lengths + else: + assert lengths == [300, 300, 400], lengths + + +def test_streaming_split_barrier(ray_start_10_cpus_shared): + ds = ray.data.range(20, parallelism=20) + ( + i1, + i2, + ) = ds.streaming_split(2, equal=True) + + @ray.remote + def consume(x, times): + i = 0 + for _ in range(times): + for _ in x.iter_rows(): + i += 1 + return i + + # Succeeds. + ray.get([consume.remote(i1, 2), consume.remote(i2, 2)]) + ray.get([consume.remote(i1, 2), consume.remote(i2, 2)]) + ray.get([consume.remote(i1, 2), consume.remote(i2, 2)]) + + # Blocks forever since one reader is stalled. + with pytest.raises(ray.exceptions.GetTimeoutError): + ray.get([consume.remote(i1, 2), consume.remote(i2, 1)], timeout=3) + + +def test_streaming_split_invalid_iterator(ray_start_10_cpus_shared): + ds = ray.data.range(20, parallelism=20) + ( + i1, + i2, + ) = ds.streaming_split(2, equal=True) + + @ray.remote + def consume(x, times): + i = 0 + for _ in range(times): + for _ in x.iter_rows(): + i += 1 + return i + + # InvalidIterator error from too many concurrent readers. + with pytest.raises(ValueError): + ray.get( + [ + consume.remote(i1, 4), + consume.remote(i2, 4), + consume.remote(i1, 4), + consume.remote(i2, 4), + ] + ) def test_e2e_option_propagation(ray_start_10_cpus_shared, restore_dataset_context): diff --git a/python/ray/experimental/tqdm_ray.py b/python/ray/experimental/tqdm_ray.py index 5428908df894..14ac12f40cb2 100644 --- a/python/ray/experimental/tqdm_ray.py +++ b/python/ray/experimental/tqdm_ray.py @@ -28,6 +28,15 @@ _manager: Optional["_BarManager"] = None +def safe_print(*args, **kwargs): + """Use this as an alternative to `print` that will not corrupt tqdm output.""" + try: + instance().hide_bars() + print(*args, **kwargs) + finally: + instance().unhide_bars() + + class tqdm: """Experimental: Ray distributed tqdm implementation. @@ -76,7 +85,9 @@ def update(self, n=1): def close(self): """Implements tqdm.tqdm.close.""" self._closed = True - self._dump_state() + # Don't bother if ray is shutdown (in __del__ hook). + if ray is not None: + self._dump_state() def _dump_state(self) -> None: if ray._private.worker.global_worker.mode == ray.WORKER_MODE: diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 0d9d8f72ce82..8d63fc069c8e 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -16,6 +16,7 @@ import ray from ray._private.dict import flatten_dict +from ray.experimental.tqdm_ray import safe_print from ray.air.util.node import _force_on_current_node from ray.tune.callback import Callback from ray.tune.logger import pretty_print @@ -707,7 +708,7 @@ def __init__( ) def _print(self, msg: str): - print(msg) + safe_print(msg) def report(self, trials: List[Trial], done: bool, *sys_info: Dict): self._print(self._progress_str(trials, done, *sys_info)) @@ -1313,7 +1314,7 @@ def __init__( self._display_handle = None def _print(self, msg: str): - print(msg) + safe_print(msg) def on_trial_result( self,
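
The change above is easiest to see from the caller's side. Below is a minimal sketch, not part of the patch, of the behavior it enables: each streaming_split() iterator can now be consumed once per epoch, repeatedly, as long as every split is read concurrently. The dataset size, split count, and thread-based consumers are illustrative assumptions that mirror the new tests.

import threading

import ray

ray.init()

ds = ray.data.range(1000, parallelism=10)
it1, it2 = ds.streaming_split(2, equal=True)


def consume(it, results, idx):
    # Each call to iter_batches() participates in one epoch of the shared stream.
    count = 0
    for batch in it.iter_batches():
        count += len(batch)
    results[idx] = count


# Before this patch only the first pass would work; now every pass re-executes
# the stream, provided both splits are read at the same time (the coordinator
# blocks, and warns after 30s, if one split lags behind the barrier).
for epoch in range(2):
    results = [0, 0]
    threads = [
        threading.Thread(target=consume, args=(it, results, i))
        for i, it in enumerate((it1, it2))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(f"epoch {epoch}: rows per split = {results}")  # e.g. [500, 500]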
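
The logging doc tip also references the new safe_print helper. A small sketch of how it might be used on the driver, assuming a local Ray cluster and a throwaway remote task (the task name and sizes are illustrative, not from the patch):

import time

import ray
from ray.experimental import tqdm_ray


@ray.remote
def busy_work(n):
    # A remote task that reports progress through the distributed tqdm.
    bar = tqdm_ray.tqdm(total=n, desc="busy_work")
    for _ in range(n):
        time.sleep(0.01)
        bar.update(1)
    bar.close()


ray.init()
ref = busy_work.remote(100)
# safe_print hides the active bars, prints, then restores them, so driver
# output does not interleave with bar redraws the way a bare print() can.
tqdm_ray.safe_print("submitted busy_work; waiting for it to finish")
ray.get(ref)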