Skip to content

Commit

Permalink
[data] [streaming] Enable actor fault tolerance by default for stream…
Browse files Browse the repository at this point in the history
…ing map operator (ray-project#33906)
  • Loading branch information
ericl authored Apr 4, 2023
1 parent a66a385 commit 6e1828e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,16 @@ def _apply_default_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str, Any
ray_remote_args["scheduling_strategy"] = "SPREAD"
else:
ray_remote_args["scheduling_strategy"] = ctx.scheduling_strategy
# Enable actor fault tolerance by default, with infinite actor recreations and
# up to N retries per task. The user can customize this in map_batches via
# extra kwargs (e.g., map_batches(..., max_restarts=0) to disable).
if "max_restarts" not in ray_remote_args:
ray_remote_args["max_restarts"] = -1
if (
"max_task_retries" not in ray_remote_args
and ray_remote_args.get("max_restarts") != 0
):
ray_remote_args["max_task_retries"] = 5
return ray_remote_args


Expand Down
26 changes: 26 additions & 0 deletions python/ray/data/tests/test_streaming_integration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import random
import pytest
import threading
import time
Expand Down Expand Up @@ -449,6 +450,31 @@ def test_can_pickle(ray_start_10_cpus_shared, restore_dataset_context):
assert ds2.count() == 1000000


def test_streaming_fault_tolerance(ray_start_10_cpus_shared, restore_dataset_context):
    """Check actor fault tolerance of ``map_batches`` under the streaming executor.

    The mapped function kills its worker process on roughly one call in ten.
    Two pipelines are run over the same base dataset:

    * with ``max_task_retries=999`` the pipeline is expected to run to
      completion despite the crashes;
    * with ``max_restarts=0`` the pipeline is expected to surface a
      ``ray.exceptions.RayActorError``.
    """
    ctx = DatasetContext.get_current()
    ctx.new_execution_backend = True
    ctx.use_streaming_executor = True

    def f(x):
        # Imported locally so the function is self-contained when it runs
        # inside an actor process.
        import os

        if random.random() <= 0.9:
            return x
        # Hard-exit the worker (no cleanup) to simulate an actor crash.
        print("force exit")
        os._exit(1)

    source = ray.data.range(1000, parallelism=100)

    # Recovery path: a generous per-task retry budget lets the run finish.
    recovering = source.map_batches(
        f,
        compute=ray.data.ActorPoolStrategy(4, 4),
        max_task_retries=999,
    )
    recovering.take_all()

    # Opt-out path: with actor restarts disabled, a crash must propagate.
    fragile = source.map_batches(
        f,
        compute=ray.data.ActorPoolStrategy(4, 4),
        max_restarts=0,
    )
    with pytest.raises(ray.exceptions.RayActorError):
        fragile.take_all()


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit 6e1828e

Please sign in to comment.