Skip to content

Commit

Permalink
[Core] Support the basic actor cancelation for async actors. (ray-pro…
Browse files Browse the repository at this point in the history
…ject#38466)

This PR allows to cancel actor tasks. The details are written in this doc https://docs.google.com/document/d/12avlF2OoFLs8lfC18i9ATXWrapxH1vFMy2UZU7L0xHI/edit (the doc is still in progress).
  • Loading branch information
rkooo567 authored Aug 21, 2023
1 parent 7a8b6a1 commit 03ae779
Show file tree
Hide file tree
Showing 36 changed files with 1,047 additions and 103 deletions.
6 changes: 5 additions & 1 deletion python/ray/_private/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,11 @@ def _deserialize_object(self, data, metadata, object_ref):
elif error_type == ErrorType.Value("LOCAL_RAYLET_DIED"):
return LocalRayletDiedError()
elif error_type == ErrorType.Value("TASK_CANCELLED"):
return TaskCancelledError()
error_message = ""
if data:
error_info = self._deserialize_error_info(data, metadata_fields)
error_message = error_info.error_message
return TaskCancelledError(error_message=error_message)
elif error_type == ErrorType.Value("OBJECT_LOST"):
return ObjectLostError(
object_ref.hex(), object_ref.owner_address(), object_ref.call_site()
Expand Down
2 changes: 1 addition & 1 deletion python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2871,7 +2871,7 @@ def cancel(object_ref: "ray.ObjectRef", *, force: bool = False, recursive: bool

if not isinstance(object_ref, ray.ObjectRef):
raise TypeError(
"ray.cancel() only supported for non-actor object refs. "
"ray.cancel() only supported for object refs. "
f"For actors, try ray.kill(). Got: {type(object_ref)}."
)
return worker.core_worker.cancel_task(object_ref, force, recursive)
Expand Down
2 changes: 2 additions & 0 deletions python/ray/_raylet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ cdef class CoreWorker:
object eventloop_for_default_cg
object thread_for_default_cg
object fd_to_cgname_dict
object task_id_to_future_lock
dict task_id_to_future
object thread_pool_for_async_event_loop

cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
Expand Down
66 changes: 61 additions & 5 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ from typing import (
Generator,
AsyncGenerator
)

import concurrent
from concurrent.futures import ThreadPoolExecutor

from libc.stdint cimport (
Expand Down Expand Up @@ -1533,7 +1535,7 @@ cdef void execute_task(
else:
return core_worker.run_async_func_or_coro_in_event_loop(
async_function, function_descriptor,
name_of_concurrency_group_to_execute, actor,
name_of_concurrency_group_to_execute, task_id, actor,
*arguments, **kwarguments)

return function(actor, *arguments, **kwarguments)
Expand All @@ -1559,7 +1561,7 @@ cdef void execute_task(
metadata_pairs, object_refs))
args = core_worker.run_async_func_or_coro_in_event_loop(
deserialize_args, function_descriptor,
name_of_concurrency_group_to_execute)
name_of_concurrency_group_to_execute, None)
else:
# Defer task cancellation (SIGINT) until after the task argument
# deserialization context has been left.
Expand Down Expand Up @@ -1640,7 +1642,8 @@ cdef void execute_task(
core_worker.run_async_func_or_coro_in_event_loop(
execute_streaming_generator_async(context),
function_descriptor,
name_of_concurrency_group_to_execute)
name_of_concurrency_group_to_execute,
task_id)
else:
execute_streaming_generator_sync(context)

Expand Down Expand Up @@ -2188,6 +2191,28 @@ cdef void delete_spilled_objects_handler(
job_id=None)


cdef void cancel_async_task(
const CTaskID &c_task_id,
const CRayFunction &ray_function,
const c_string c_name_of_concurrency_group_to_execute) nogil:
with gil:
function_descriptor = CFunctionDescriptorToPython(
ray_function.GetFunctionDescriptor())
name_of_concurrency_group_to_execute = \
c_name_of_concurrency_group_to_execute.decode("ascii")
task_id = TaskID(c_task_id.Binary())

worker = ray._private.worker.global_worker
eventloop, _ = worker.core_worker.get_event_loop(
function_descriptor, name_of_concurrency_group_to_execute)
future = worker.core_worker.get_queued_future(task_id)
if future is not None:
eventloop.call_soon_threadsafe(future.cancel)
# else, the task is already finished. If the task
# wasn't finished (task is queued on a client or server side),
# this method shouldn't have been called.


cdef void unhandled_exception_handler(const CRayObject& error) nogil:
with gil:
worker = ray._private.worker.global_worker
Expand Down Expand Up @@ -2945,6 +2970,7 @@ cdef class CoreWorker:
options.restore_spilled_objects = restore_spilled_objects_handler
options.delete_spilled_objects = delete_spilled_objects_handler
options.unhandled_exception_handler = unhandled_exception_handler
options.cancel_async_task = cancel_async_task
options.get_lang_stack = get_py_stack
options.is_local_mode = local_mode
options.kill_main = kill_main_task
Expand All @@ -2965,6 +2991,8 @@ cdef class CoreWorker:
self.fd_to_cgname_dict = None
self.eventloop_for_default_cg = None
self.current_runtime_env = None
self.task_id_to_future_lock = threading.Lock()
self.task_id_to_future = {}
self.thread_pool_for_async_event_loop = None

def shutdown(self):
Expand Down Expand Up @@ -3721,7 +3749,8 @@ cdef class CoreWorker:
CObjectID c_object_id = object_ref.native()
CRayStatus status = CRayStatus.OK()

status = CCoreWorkerProcess.GetCoreWorker().CancelTask(
with nogil:
status = CCoreWorkerProcess.GetCoreWorker().CancelTask(
c_object_id, force_kill, recursive)

if not status.ok():
Expand Down Expand Up @@ -4148,6 +4177,7 @@ cdef class CoreWorker:
func_or_coro: Union[Callable[[Any, Any], Awaitable[Any]], Awaitable],
function_descriptor: FunctionDescriptor,
specified_cgname: str,
task_id: Optional[TaskID],
*args,
**kwargs):
"""Run the async function or coroutine to the event loop.
Expand All @@ -4157,6 +4187,10 @@ cdef class CoreWorker:
func_or_coro: Async function (not a generator) or awaitable objects.
function_descriptor: The function descriptor.
specified_cgname: The name of a concurrent group.
task_id: The task ID to track the future. If None is provided
the future is not tracked with a task ID.
(e.g., When we deserialize the arguments, we don't want to
track the task_id -> future mapping).
args: The arguments for the async function.
kwargs: The keyword arguments for the async function.
"""
Expand All @@ -4181,11 +4215,24 @@ cdef class CoreWorker:
coroutine = func_or_coro(*args, **kwargs)

future = asyncio.run_coroutine_threadsafe(coroutine, eventloop)
if task_id:
with self.task_id_to_future_lock:
self.task_id_to_future[task_id] = asyncio.wrap_future(
future, loop=eventloop)

future.add_done_callback(lambda _: event.Notify())
with nogil:
(CCoreWorkerProcess.GetCoreWorker()
.YieldCurrentFiber(event))
return future.result()
try:
result = future.result()
except concurrent.futures.CancelledError:
raise TaskCancelledError(task_id)
finally:
if task_id:
with self.task_id_to_future_lock:
self.task_id_to_future.pop(task_id)
return result

def stop_and_join_asyncio_threads_if_exist(self):
event_loops = []
Expand Down Expand Up @@ -4215,6 +4262,15 @@ cdef class CoreWorker:
return (CCoreWorkerProcess.GetCoreWorker().GetWorkerContext()
.CurrentActorMaxConcurrency())

def get_queued_future(self, task_id: Optional[TaskID]) -> asyncio.Future:
"""Get a asyncio.Future that's queued in the event loop."""
with self.task_id_to_future_lock:
return self.task_id_to_future.get(task_id)

def get_task_id_to_future(self):
# Testing-only
return self.task_id_to_future

def get_current_runtime_env(self) -> str:
# This should never change, so we can safely cache it to avoid ser/de
if self.current_runtime_env is None:
Expand Down
14 changes: 10 additions & 4 deletions python/ray/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,19 @@ class TaskCancelledError(RayError):
cancelled.
"""

def __init__(self, task_id: Optional[TaskID] = None):
def __init__(
self, task_id: Optional[TaskID] = None, error_message: Optional[str] = None
):
self.task_id = task_id
self.error_message = error_message

def __str__(self):
if self.task_id is None:
return "This task or its dependency was cancelled by"
return "Task: " + str(self.task_id) + " was cancelled"
msg = ""
if self.task_id:
msg = "Task: " + str(self.task_id) + " was cancelled. "
if self.error_message:
msg += self.error_message
return msg


@PublicAPI
Expand Down
5 changes: 5 additions & 0 deletions python/ray/includes/libcoreworker.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
const c_string&,
const c_vector[c_string]&) nogil) run_on_util_worker_handler
(void(const CRayObject&) nogil) unhandled_exception_handler
(void(
const CTaskID &c_task_id,
const CRayFunction &ray_function,
const c_string c_name_of_concurrency_group_to_execute
) nogil) cancel_async_task
(void(c_string *stack_out) nogil) get_lang_stack
c_bool is_local_mode
int num_workers
Expand Down
1 change: 1 addition & 0 deletions python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ py_test_module_list(
"test_basic_4.py",
"test_basic_5.py",
"test_cancel.py",
"test_actor_cancel.py",
"test_dashboard_profiler.py",
"test_gcs_fault_tolerance.py",
"test_gcs_utils.py",
Expand Down
Loading

0 comments on commit 03ae779

Please sign in to comment.