[data] [docs] Improve the streaming_split pydoc (ray-project#33424)
ericl authored Mar 22, 2023
1 parent 636a699 commit 83dc07a
Showing 1 changed file with 37 additions and 7 deletions.
44 changes: 37 additions & 7 deletions python/ray/data/dataset.py
@@ -1159,17 +1159,44 @@ def streaming_split(
be used to read disjoint subsets of the dataset in parallel.
This method is the recommended way to consume Datasets from multiple processes
- (e.g., for distributed training). It requires streaming execution mode.
+ (e.g., for distributed training), and requires streaming execution mode.
The returned iterators are Ray-serializable and can be freely passed to any
Ray task or actor.
Streaming split works by delegating the execution of this Dataset to a
coordinator actor. The coordinator pulls block references from the executed
stream and divides those blocks among ``n`` output iterators. Iterators pull
blocks from the coordinator actor to return to their caller on ``next``.

The returned iterators are also repeatable; each iteration will trigger a
new execution of the Dataset. There is an implicit barrier at the start of
each iteration, which means that ``next`` must be called on all iterators
before the iteration starts.

Warning: because all iterators pull blocks from the same Dataset execution,
if one iterator falls behind, the other iterators may be stalled.
Examples:
>>> import ray
>>> ds = ray.data.range(1000000)
>>> it1, it2 = ds.streaming_split(2, equal=True)
>>> list(it1.iter_batches()) # doctest: +SKIP
>>> list(it2.iter_batches()) # doctest: +SKIP
>>> # Can consume from both iterators in parallel.
>>> @ray.remote
... def consume(it):
... for batch in it.iter_batches():
... print(batch)
>>> ray.get([consume.remote(it1), consume.remote(it2)]) # doctest: +SKIP
>>> # Can loop over the iterators multiple times (multiple epochs).
>>> @ray.remote
... def train(it):
... NUM_EPOCHS = 100
... for _ in range(NUM_EPOCHS):
... for batch in it.iter_batches():
... print(batch)
>>> ray.get([train.remote(it1), train.remote(it2)]) # doctest: +SKIP
>>> # ERROR: this will block waiting for a read on `it2` to start.
>>> ray.get(train.remote(it1)) # doctest: +SKIP
Args:
n: Number of output iterators to return.
@@ -1178,10 +1205,13 @@ def streaming_split(
slightly more or fewer rows than the others, but no data will be dropped.
locality_hints: Specify the node ids corresponding to each iterator
location. Datasets will try to minimize data movement based on the
- iterator output locations. This list must have length ``n``.
+ iterator output locations. This list must have length ``n``. You can
+ get the current node id of a task or actor by calling
+ ``ray.get_runtime_context().get_node_id()``.
Returns:
- The output iterator splits.
+ The output iterator splits. These iterators are Ray-serializable and can
+ be freely passed to any Ray task or actor.
"""
return StreamSplitDatasetIterator.create(self, n, equal, locality_hints)
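
To make the coordinator mechanism described in the docstring concrete, here is
a minimal, self-contained sketch of the same pattern. It is an illustration
only, not Ray's actual StreamSplitDatasetIterator implementation; the
Coordinator actor and make_split() helper are hypothetical names invented for
this example.

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class Coordinator:
    """Pulls items from one shared stream and hands them out on demand."""

    def __init__(self, items):
        self._stream = iter(items)  # stands in for the executing block stream

    def next_item(self):
        # Each output iterator calls this on `next`; None signals exhaustion.
        return next(self._stream, None)

def make_split(coordinator):
    # Each split is a plain generator that pulls from the shared coordinator.
    while True:
        item = ray.get(coordinator.next_item.remote())
        if item is None:
            return
        yield item

coordinator = Coordinator.remote(range(8))  # eight toy "blocks"
it1, it2 = make_split(coordinator), make_split(coordinator)

# Interleaved reads: the two splits receive disjoint items from one stream.
print(next(it1), next(it2), next(it1), next(it2))  # 0 1 2 3

Because both generators draw from one shared actor, a split that stops calling
next simply stops receiving items; this is the same property behind the stall
warning above when one real iterator falls behind.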
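
The locality_hints argument is easiest to see in a full setup. The sketch below
(assuming a hypothetical Worker actor class; only streaming_split, iter_batches,
and ray.get_runtime_context().get_node_id() come from the docstring above)
shows one way to collect the node id of each consuming actor and pass those
ids to streaming_split so that each split is placed near its consumer.

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class Worker:
    def get_node_id(self) -> str:
        # The docstring's recommended way to find where an actor runs.
        return ray.get_runtime_context().get_node_id()

    def consume(self, it) -> int:
        # Count batches so the example returns something checkable.
        num_batches = 0
        for _batch in it.iter_batches():
            num_batches += 1
        return num_batches

workers = [Worker.remote() for _ in range(2)]
node_ids = ray.get([w.get_node_id.remote() for w in workers])

ds = ray.data.range(1000000)
# One node id per split: locality_hints must have length n.
it1, it2 = ds.streaming_split(2, equal=True, locality_hints=node_ids)

# Read both splits in parallel, as the barrier semantics above require.
print(ray.get([workers[0].consume.remote(it1), workers[1].consume.remote(it2)]))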

