Skip to content

Commit

Permalink
[data] [streaming] Support iterating over streaming_split() iterators…
Browse files Browse the repository at this point in the history
… multiple times (ray-project#33303)

In order to support Train integration, the iterators returned from Dataset.streaming_split() must be repeatable, as with the other iterators.
  • Loading branch information
ericl authored Mar 16, 2023
1 parent f1a1430 commit c93e8ae
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 40 deletions.
2 changes: 2 additions & 0 deletions doc/source/ray-observability/ray-logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ Limitations:
- Only a subset of tqdm functionality is supported. Refer to the ray_tqdm `implementation <https://github.com/ray-project/ray/blob/master/python/ray/experimental/tqdm_ray.py>`__ for more details.
- Performance may be poor if there are more than a couple thousand updates per second (updates are not batched).

Tip: To avoid `print` statements from the driver conflicting with tqdm output, use `ray.experimental.tqdm_ray.safe_print` instead.

How to set up loggers
~~~~~~~~~~~~~~~~~~~~~
When using ray, all of the tasks and actors are executed remotely in Ray's worker processes.
Expand Down
8 changes: 6 additions & 2 deletions python/ray/data/_internal/dataset_iterator_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ray.data._internal.util import _default_batch_format
from ray.data.block import DataBatch
from ray.data.context import DatasetContext
from ray.data.dataset_iterator import DatasetIterator
from ray.data._internal.block_batching import batch_block_refs

Expand All @@ -24,6 +25,7 @@ def __init__(
base_dataset: "Dataset",
):
self._base_dataset = base_dataset
self._base_context = DatasetContext.get_current()

def __repr__(self) -> str:
return f"DatasetIterator({self._base_dataset})"
Expand All @@ -40,6 +42,8 @@ def iter_batches(
_collate_fn: Optional[Callable[[DataBatch], Any]] = None,
) -> Iterator[DataBatch]:

DatasetContext._set_current(self._base_context)

ds = self._base_dataset
block_iterator, stats, executor = ds._plan.execute_to_iterator()
ds._current_executor = executor
Expand Down Expand Up @@ -85,5 +89,5 @@ def __getattr__(self, name):
)

return getattr(self._base_dataset, name)
else:
return super().__getattr__(name)

raise AttributeError()
3 changes: 3 additions & 0 deletions python/ray/data/_internal/execution/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,9 @@ def get_next(self, output_split_idx: Optional[int] = None) -> RefBundle:
Args:
output_split_idx: The output split index to get results for. This arg is
only allowed for iterators created by `Dataset.streaming_split()`.
Raises:
StopIteration if there are no more outputs to return.
"""
if output_split_idx is not None:
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/_internal/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ProgressBar:

def __init__(self, name: str, total: int, position: int = 0):
self._desc = name
if not _enabled or threading.current_thread() is not threading.main_thread():
if not _enabled:
self._bar = None
elif tqdm:
ctx = ray.data.context.DatasetContext.get_current()
Expand Down
97 changes: 83 additions & 14 deletions python/ray/data/_internal/stream_split_dataset_iterator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import logging
import sys
import time
import threading
from typing import (
List,
Expand All @@ -27,6 +28,7 @@
from ray.data._internal.execution.operators.output_splitter import OutputSplitter
from ray.data._internal.execution.interfaces import NodeIdStr, RefBundle
from ray.types import ObjectRef
from ray.util.debug import log_once
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

if TYPE_CHECKING:
Expand All @@ -41,6 +43,9 @@
logger = logging.getLogger(__name__)


BLOCKED_CLIENT_WARN_TIMEOUT = 30


class StreamSplitDatasetIterator(DatasetIterator):
"""Implements a collection of iterators over a shared data stream."""

Expand Down Expand Up @@ -93,15 +98,20 @@ def iter_batches(
"""Implements DatasetIterator."""

def gen_blocks() -> Iterator[ObjectRef[Block]]:
cur_epoch = ray.get(
self._coord_actor.start_epoch.remote(self._output_split_idx)
)
future: ObjectRef[
Optional[ObjectRef[Block]]
] = self._coord_actor.get.remote(self._output_split_idx)
] = self._coord_actor.get.remote(cur_epoch, self._output_split_idx)
while True:
block_ref: Optional[ObjectRef[Block]] = ray.get(future)
if not block_ref:
break
else:
future = self._coord_actor.get.remote(self._output_split_idx)
future = self._coord_actor.get.remote(
cur_epoch, self._output_split_idx
)
yield block_ref

yield from batch_block_refs(
Expand Down Expand Up @@ -154,29 +164,54 @@ def __init__(
self._n = n
self._equal = equal
self._locality_hints = locality_hints
self._finished = False
self._lock = threading.RLock()

# Guarded by self._lock.
self._next_bundle: Dict[int, RefBundle] = {}
self._unfinished_clients_in_epoch = n
self._cur_epoch = -1

executor = StreamingExecutor(copy.deepcopy(ctx.execution_options))
def gen_epochs():
while True:
executor = StreamingExecutor(copy.deepcopy(ctx.execution_options))

def add_split_op(dag):
return OutputSplitter(dag, n, equal, locality_hints)
def add_split_op(dag):
return OutputSplitter(dag, n, equal, locality_hints)

self._output_iterator = execute_to_legacy_bundle_iterator(
executor,
dataset._plan,
True,
dataset._plan._dataset_uuid,
dag_rewrite=add_split_op,
)
output_iterator = execute_to_legacy_bundle_iterator(
executor,
dataset._plan,
True,
dataset._plan._dataset_uuid,
dag_rewrite=add_split_op,
)
yield output_iterator

self._next_epoch = gen_epochs()
self._output_iterator = None

def start_epoch(self, split_idx: int) -> str:
"""Called to start an epoch.
def get(self, output_split_idx: int) -> Optional[ObjectRef[Block]]:
Returns:
UUID for the epoch, which must be used when accessing results via get().
"""

# Wait for all clients to arrive at the barrier before starting a new epoch.
epoch_id = self._barrier(split_idx)
return epoch_id

def get(self, epoch_id: int, output_split_idx: int) -> Optional[ObjectRef[Block]]:
"""Blocking get operation.
This is intended to be called concurrently from multiple clients.
"""

if epoch_id != self._cur_epoch:
raise ValueError(
"Invalid iterator: the datastream has moved on to another epoch."
)

try:
# Ensure there is at least one bundle.
with self._lock:
Expand All @@ -201,3 +236,37 @@ def get(self, output_split_idx: int) -> Optional[ObjectRef[Block]]:
return block
except StopIteration:
return None

def _barrier(self, split_idx: int) -> int:
"""Arrive and block until the start of the given epoch."""

# Decrement and await all clients to arrive here.
with self._lock:
starting_epoch = self._cur_epoch
self._unfinished_clients_in_epoch -= 1

start_time = time.time()
while (
self._cur_epoch == starting_epoch and self._unfinished_clients_in_epoch != 0
):
if time.time() - start_time > BLOCKED_CLIENT_WARN_TIMEOUT:
if log_once(f"stream_split_blocked_{split_idx}_{starting_epoch}"):
logger.warning(
f"StreamSplitDatasetIterator(epoch={starting_epoch}, "
f"split={split_idx}) blocked waiting on other clients "
f"for more than {BLOCKED_CLIENT_WARN_TIMEOUT}s. All "
"clients must read from the DatasetIterator splits at "
"the same time. This warning will not be printed again "
"for this epoch."
)
time.sleep(0.1)

# Advance to the next epoch.
with self._lock:
if self._cur_epoch == starting_epoch:
self._cur_epoch += 1
self._unfinished_clients_in_epoch = self._n
self._output_iterator = next(self._next_epoch)

assert self._output_iterator is not None
return starting_epoch + 1
110 changes: 90 additions & 20 deletions python/ray/data/tests/test_streaming_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,29 @@ def run(self):
def test_streaming_split_e2e(ray_start_10_cpus_shared):
def get_lengths(*iterators, use_iter_batches=True):
lengths = []
for it in iterators:
x = 0
if use_iter_batches:
for batch in it.iter_batches():
x += len(batch)
else:
for _ in it.iter_rows():
x += 1
lengths.append(x)

class Runner(threading.Thread):
def __init__(self, it):
self.it = it
super().__init__()

def run(self):
it = self.it
x = 0
if use_iter_batches:
for batch in it.iter_batches():
x += len(batch)
else:
for _ in it.iter_rows():
x += 1
lengths.append(x)

runners = [Runner(it) for it in iterators]
for r in runners:
r.start()
for r in runners:
r.join()

lengths.sort()
return lengths

Expand All @@ -109,35 +123,91 @@ def get_lengths(*iterators, use_iter_batches=True):
i1,
i2,
) = ds.streaming_split(2, equal=True)
lengths = get_lengths(i1, i2)
assert lengths == [500, 500], lengths
for _ in range(2):
lengths = get_lengths(i1, i2)
assert lengths == [500, 500], lengths

ds = ray.data.range(1)
(
i1,
i2,
) = ds.streaming_split(2, equal=True)
lengths = get_lengths(i1, i2)
assert lengths == [0, 0], lengths
for _ in range(2):
lengths = get_lengths(i1, i2)
assert lengths == [0, 0], lengths

ds = ray.data.range(1)
(
i1,
i2,
) = ds.streaming_split(2, equal=False)
lengths = get_lengths(i1, i2)
assert lengths == [0, 1], lengths
for _ in range(2):
lengths = get_lengths(i1, i2)
assert lengths == [0, 1], lengths

ds = ray.data.range(1000, parallelism=10)
for equal_split, use_iter_batches in itertools.product(
[True, False], [True, False]
):
i1, i2, i3 = ds.streaming_split(3, equal=equal_split)
lengths = get_lengths(i1, i2, i3, use_iter_batches=use_iter_batches)
if equal_split:
assert lengths == [333, 333, 333], lengths
else:
assert lengths == [300, 300, 400], lengths
for _ in range(2):
lengths = get_lengths(i1, i2, i3, use_iter_batches=use_iter_batches)
if equal_split:
assert lengths == [333, 333, 333], lengths
else:
assert lengths == [300, 300, 400], lengths


def test_streaming_split_barrier(ray_start_10_cpus_shared):
ds = ray.data.range(20, parallelism=20)
(
i1,
i2,
) = ds.streaming_split(2, equal=True)

@ray.remote
def consume(x, times):
i = 0
for _ in range(times):
for _ in x.iter_rows():
i += 1
return i

# Succeeds.
ray.get([consume.remote(i1, 2), consume.remote(i2, 2)])
ray.get([consume.remote(i1, 2), consume.remote(i2, 2)])
ray.get([consume.remote(i1, 2), consume.remote(i2, 2)])

# Blocks forever since one reader is stalled.
with pytest.raises(ray.exceptions.GetTimeoutError):
ray.get([consume.remote(i1, 2), consume.remote(i2, 1)], timeout=3)


def test_streaming_split_invalid_iterator(ray_start_10_cpus_shared):
ds = ray.data.range(20, parallelism=20)
(
i1,
i2,
) = ds.streaming_split(2, equal=True)

@ray.remote
def consume(x, times):
i = 0
for _ in range(times):
for _ in x.iter_rows():
i += 1
return i

# InvalidIterator error from too many concurrent readers.
with pytest.raises(ValueError):
ray.get(
[
consume.remote(i1, 4),
consume.remote(i2, 4),
consume.remote(i1, 4),
consume.remote(i2, 4),
]
)


def test_e2e_option_propagation(ray_start_10_cpus_shared, restore_dataset_context):
Expand Down
13 changes: 12 additions & 1 deletion python/ray/experimental/tqdm_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@
_manager: Optional["_BarManager"] = None


def safe_print(*args, **kwargs):
"""Use this as an alternative to `print` that will not corrupt tqdm output."""
try:
instance().hide_bars()
print(*args, **kwargs)
finally:
instance().unhide_bars()


class tqdm:
"""Experimental: Ray distributed tqdm implementation.
Expand Down Expand Up @@ -76,7 +85,9 @@ def update(self, n=1):
def close(self):
"""Implements tqdm.tqdm.close."""
self._closed = True
self._dump_state()
# Don't bother if ray is shutdown (in __del__ hook).
if ray is not None:
self._dump_state()

def _dump_state(self) -> None:
if ray._private.worker.global_worker.mode == ray.WORKER_MODE:
Expand Down
Loading

0 comments on commit c93e8ae

Please sign in to comment.