Skip to content

Commit

Permalink
[core][accelerated DAGs] Make DAG teardown blocking and fix bug durin…
Browse files Browse the repository at this point in the history
…g close (ray-project#45099)

## Why are these changes needed?

Contains a few fixes related to DAG teardown:
- Removes an unnecessary `.close()` call that would error if the DAG has
a single output (instead of a MultiOutputNode)
- Makes `dag.teardown()` blocking to ensure that actors can be reused
after the teardown call returns.
- Makes DAG teardown in `__del__` asynchronous. If it is synchronous, it can
hang the driver upon shutdown. I'm not exactly sure why, but I believe
the hang happens if the CoreWorker is shut down before `dag.teardown()` is
complete.

---------
Signed-off-by: Stephanie Wang <[email protected]>
  • Loading branch information
stephanie-wang authored May 3, 2024
1 parent 36749b8 commit 4937ac3
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 26 deletions.
80 changes: 56 additions & 24 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __init__(

# ObjectRef for each worker's task. The task is an infinite loop that
# repeatedly executes the method specified in the DAG.
self.worker_task_refs: List["ray.ObjectRef"] = []
self.worker_task_refs: Dict["ray.actor.ActorHandle", "ray.ObjectRef"] = {}
# Set of actors present in the DAG.
self.actor_refs = set()

Expand Down Expand Up @@ -525,14 +525,14 @@ def _get_node_id(self):

# Assign the task with the correct input and output buffers.
worker_fn = task.dag_node._get_remote_method("__ray_call__")
self.worker_task_refs.append(
worker_fn.options(concurrency_group="_ray_system").remote(
do_exec_compiled_task,
resolved_args,
task.dag_node.get_method_name(),
output_wrapper_fn=task.output_wrapper_fn,
has_type_hints=bool(self._type_hints),
)
self.worker_task_refs[
task.dag_node._get_actor_handle()
] = worker_fn.options(concurrency_group="_ray_system").remote(
do_exec_compiled_task,
resolved_args,
task.dag_node.get_method_name(),
output_wrapper_fn=task.output_wrapper_fn,
has_type_hints=bool(self._type_hints),
)

# Wrapper function for inputs provided to dag.execute().
Expand Down Expand Up @@ -586,9 +586,37 @@ def __init__(self):
super().__init__(daemon=True)
self.in_teardown = False

def teardown(self):
def wait_teardown(self):
    """Block until every compiled-DAG worker task has finished.

    Each task ref in ``outer.worker_task_refs`` is first awaited for up
    to 10 seconds. If that wait times out, a warning naming the actor is
    logged and the same ref is then awaited again with no timeout, so
    this call can block indefinitely if the actor task never exits.
    Exceptions raised by the worker tasks themselves are ignored.
    """
    for actor, ref in outer.worker_task_refs.items():
        try:
            ray.get(ref, timeout=10)
        except ray.exceptions.GetTimeoutError:
            # `Logger.warn` is a deprecated alias; `warning` is the
            # supported spelling.
            logger.warning(
                f"Compiled DAG actor {actor} is still running 10s "
                "after teardown(). Teardown may hang."
            )
        except Exception:
            # We just want to check that the task has finished so we
            # don't care if the actor task ended in an exception.
            continue
        else:
            # Task finished within the timeout; move on to the next one.
            continue

        # Only reached after a timeout: keep waiting, without a
        # deadline, until the task actually exits.
        try:
            ray.get(ref)
        except Exception:
            pass

def teardown(self, wait: bool):
if self.in_teardown:
if wait:
self.wait_teardown()
return

logger.info("Tearing down compiled DAG")

outer._dag_submitter.close()
Expand All @@ -604,24 +632,20 @@ def teardown(self):
except Exception:
logger.exception("Error cancelling worker task")
pass
logger.info("Waiting for worker tasks to exit")
for ref in outer.worker_task_refs:
try:
ray.get(ref)
except Exception:
pass
logger.info("Teardown complete")

if wait:
logger.info("Waiting for worker tasks to exit")
self.wait_teardown()
logger.info("Teardown complete")

def run(self):
try:
ray.get(outer.worker_task_refs)
ray.get(list(outer.worker_task_refs.values()))
except Exception as e:
logger.debug(f"Handling exception from worker tasks: {e}")
if self.in_teardown:
return
for output_channel in outer.dag_output_channels:
output_channel.close()
self.teardown()
self.teardown(wait=True)

monitor = Monitor()
monitor.start()
Expand Down Expand Up @@ -701,13 +725,21 @@ async def execute_async(
return AwaitableDAGOutput(fut, self._dag_output_fetcher)

def teardown(self):
"""Teardown and cancel all worker tasks for this DAG."""
"""Teardown and cancel all actor tasks for this DAG. After this
function returns, the actors should be available to execute new tasks
or compile a new DAG."""
monitor = getattr(self, "_monitor", None)
if monitor is not None:
monitor.teardown()
monitor.teardown(wait=True)

def __del__(self):
self.teardown()
monitor = getattr(self, "_monitor", None)
if monitor is not None:
# Teardown asynchronously.
# NOTE(swang): Somehow, this can get called after the CoreWorker
# has already been destructed, so it is not safe to block in
# ray.get.
monitor.teardown(wait=False)


@DeveloperAPI
Expand Down
48 changes: 46 additions & 2 deletions python/ray/dag/tests/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,10 @@ def f(x):


def test_dag_fault_tolerance(ray_start_regular_shared):
actors = [Actor.remote(0, fail_after=100, sys_exit=False) for _ in range(4)]
actors = [
Actor.remote(0, fail_after=100 if i == 0 else None, sys_exit=False)
for i in range(4)
]
with InputNode() as i:
out = [a.inc.bind(i) for a in actors]
dag = MultiOutputNode(out)
Expand All @@ -235,9 +238,30 @@ def test_dag_fault_tolerance(ray_start_regular_shared):

compiled_dag.teardown()

# All actors are still alive.
ray.get([actor.echo.remote("hello") for actor in actors])

# Remaining actors can be reused.
actors.pop(0)
with InputNode() as i:
out = [a.inc.bind(i) for a in actors]
dag = MultiOutputNode(out)

compiled_dag = dag.experimental_compile()
for i in range(100):
output_channels = compiled_dag.execute(1)
# TODO(swang): Replace with fake ObjectRef.
output_channels.begin_read()
output_channels.end_read()

compiled_dag.teardown()


def test_dag_fault_tolerance_sys_exit(ray_start_regular_shared):
actors = [Actor.remote(0, fail_after=100, sys_exit=True) for _ in range(4)]
actors = [
Actor.remote(0, fail_after=100 if i == 0 else None, sys_exit=True)
for i in range(4)
]
with InputNode() as i:
out = [a.inc.bind(i) for a in actors]
dag = MultiOutputNode(out)
Expand All @@ -257,6 +281,26 @@ def test_dag_fault_tolerance_sys_exit(ray_start_regular_shared):
output_channels.begin_read()
output_channels.end_read()

# Remaining actors are still alive.
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actors[0].echo.remote("hello"))
actors.pop(0)
ray.get([actor.echo.remote("hello") for actor in actors])

# Remaining actors can be reused.
with InputNode() as i:
out = [a.inc.bind(i) for a in actors]
dag = MultiOutputNode(out)

compiled_dag = dag.experimental_compile()
for i in range(100):
output_channels = compiled_dag.execute(1)
# TODO(swang): Replace with fake ObjectRef.
output_channels.begin_read()
output_channels.end_read()

compiled_dag.teardown()


def test_dag_teardown_while_running(ray_start_regular_shared):
a = Actor.remote(0)
Expand Down

0 comments on commit 4937ac3

Please sign in to comment.