diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index 5e3281c1c197..35925c8b455f 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -296,7 +296,11 @@ def _deserialize_object(self, data, metadata, object_ref): elif error_type == ErrorType.Value("LOCAL_RAYLET_DIED"): return LocalRayletDiedError() elif error_type == ErrorType.Value("TASK_CANCELLED"): - return TaskCancelledError() + error_message = "" + if data: + error_info = self._deserialize_error_info(data, metadata_fields) + error_message = error_info.error_message + return TaskCancelledError(error_message=error_message) elif error_type == ErrorType.Value("OBJECT_LOST"): return ObjectLostError( object_ref.hex(), object_ref.owner_address(), object_ref.call_site() diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index dafabf7e894b..aded960bd52f 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -2871,7 +2871,7 @@ def cancel(object_ref: "ray.ObjectRef", *, force: bool = False, recursive: bool if not isinstance(object_ref, ray.ObjectRef): raise TypeError( - "ray.cancel() only supported for non-actor object refs. " + "ray.cancel() only supported for object refs. " f"For actors, try ray.kill(). Got: {type(object_ref)}." ) return worker.core_worker.cancel_task(object_ref, force, recursive) diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 8a24f0d8ef2c..65ce1521a6e5 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -124,6 +124,8 @@ cdef class CoreWorker: object eventloop_for_default_cg object thread_for_default_cg object fd_to_cgname_dict + object task_id_to_future_lock + dict task_id_to_future object thread_pool_for_async_event_loop cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata, diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 43ada4130442..6fda4fe0d503 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -32,6 +32,8 @@ from typing import ( Generator, AsyncGenerator ) + +import concurrent from concurrent.futures import ThreadPoolExecutor from libc.stdint cimport ( @@ -1533,7 +1535,7 @@ cdef void execute_task( else: return core_worker.run_async_func_or_coro_in_event_loop( async_function, function_descriptor, - name_of_concurrency_group_to_execute, actor, + name_of_concurrency_group_to_execute, task_id, actor, *arguments, **kwarguments) return function(actor, *arguments, **kwarguments) @@ -1559,7 +1561,7 @@ cdef void execute_task( metadata_pairs, object_refs)) args = core_worker.run_async_func_or_coro_in_event_loop( deserialize_args, function_descriptor, - name_of_concurrency_group_to_execute) + name_of_concurrency_group_to_execute, None) else: # Defer task cancellation (SIGINT) until after the task argument # deserialization context has been left. 
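For orientation, this is roughly the user-visible behavior that the serialization and ray.cancel changes above (together with the test file added later in this patch) enable. The sketch below is illustrative only; the Sleeper actor and its method names are made up and are not part of the patch.

import asyncio

import ray
from ray.exceptions import RayTaskError, TaskCancelledError


@ray.remote
class Sleeper:
    async def sleep(self):
        try:
            await asyncio.sleep(1000)
        except asyncio.CancelledError:
            # Cancellation is now delivered into the running coroutine.
            raise


actor = Sleeper.remote()
ref = actor.sleep.remote()
ray.get(actor.__ray_ready__.remote())  # Make sure the actor is up.
ray.cancel(ref)  # force=True is rejected for actor tasks.

try:
    ray.get(ref)
except RayTaskError:
    # Raised when the task was already running and got interrupted.
    pass
except TaskCancelledError as e:
    # Raised when the task was still queued; e carries the error message
    # attached by the serialization change above.
    print(e)
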
@@ -1640,7 +1642,8 @@ cdef void execute_task( core_worker.run_async_func_or_coro_in_event_loop( execute_streaming_generator_async(context), function_descriptor, - name_of_concurrency_group_to_execute) + name_of_concurrency_group_to_execute, + task_id) else: execute_streaming_generator_sync(context) @@ -2188,6 +2191,28 @@ cdef void delete_spilled_objects_handler( job_id=None) +cdef void cancel_async_task( + const CTaskID &c_task_id, + const CRayFunction &ray_function, + const c_string c_name_of_concurrency_group_to_execute) nogil: + with gil: + function_descriptor = CFunctionDescriptorToPython( + ray_function.GetFunctionDescriptor()) + name_of_concurrency_group_to_execute = \ + c_name_of_concurrency_group_to_execute.decode("ascii") + task_id = TaskID(c_task_id.Binary()) + + worker = ray._private.worker.global_worker + eventloop, _ = worker.core_worker.get_event_loop( + function_descriptor, name_of_concurrency_group_to_execute) + future = worker.core_worker.get_queued_future(task_id) + if future is not None: + eventloop.call_soon_threadsafe(future.cancel) + # else, the task is already finished. If the task + # wasn't finished (task is queued on a client or server side), + # this method shouldn't have been called. + + cdef void unhandled_exception_handler(const CRayObject& error) nogil: with gil: worker = ray._private.worker.global_worker @@ -2945,6 +2970,7 @@ cdef class CoreWorker: options.restore_spilled_objects = restore_spilled_objects_handler options.delete_spilled_objects = delete_spilled_objects_handler options.unhandled_exception_handler = unhandled_exception_handler + options.cancel_async_task = cancel_async_task options.get_lang_stack = get_py_stack options.is_local_mode = local_mode options.kill_main = kill_main_task @@ -2965,6 +2991,8 @@ cdef class CoreWorker: self.fd_to_cgname_dict = None self.eventloop_for_default_cg = None self.current_runtime_env = None + self.task_id_to_future_lock = threading.Lock() + self.task_id_to_future = {} self.thread_pool_for_async_event_loop = None def shutdown(self): @@ -3721,7 +3749,8 @@ cdef class CoreWorker: CObjectID c_object_id = object_ref.native() CRayStatus status = CRayStatus.OK() - status = CCoreWorkerProcess.GetCoreWorker().CancelTask( + with nogil: + status = CCoreWorkerProcess.GetCoreWorker().CancelTask( c_object_id, force_kill, recursive) if not status.ok(): @@ -4148,6 +4177,7 @@ cdef class CoreWorker: func_or_coro: Union[Callable[[Any, Any], Awaitable[Any]], Awaitable], function_descriptor: FunctionDescriptor, specified_cgname: str, + task_id: Optional[TaskID], *args, **kwargs): """Run the async function or coroutine to the event loop. @@ -4157,6 +4187,10 @@ cdef class CoreWorker: func_or_coro: Async function (not a generator) or awaitable objects. function_descriptor: The function descriptor. specified_cgname: The name of a concurrent group. + task_id: The task ID to track the future. If None is provided + the future is not tracked with a task ID. + (e.g., When we deserialize the arguments, we don't want to + track the task_id -> future mapping). args: The arguments for the async function. kwargs: The keyword arguments for the async function. 
""" @@ -4181,11 +4215,24 @@ cdef class CoreWorker: coroutine = func_or_coro(*args, **kwargs) future = asyncio.run_coroutine_threadsafe(coroutine, eventloop) + if task_id: + with self.task_id_to_future_lock: + self.task_id_to_future[task_id] = asyncio.wrap_future( + future, loop=eventloop) + future.add_done_callback(lambda _: event.Notify()) with nogil: (CCoreWorkerProcess.GetCoreWorker() .YieldCurrentFiber(event)) - return future.result() + try: + result = future.result() + except concurrent.futures.CancelledError: + raise TaskCancelledError(task_id) + finally: + if task_id: + with self.task_id_to_future_lock: + self.task_id_to_future.pop(task_id) + return result def stop_and_join_asyncio_threads_if_exist(self): event_loops = [] @@ -4215,6 +4262,15 @@ cdef class CoreWorker: return (CCoreWorkerProcess.GetCoreWorker().GetWorkerContext() .CurrentActorMaxConcurrency()) + def get_queued_future(self, task_id: Optional[TaskID]) -> asyncio.Future: + """Get a asyncio.Future that's queued in the event loop.""" + with self.task_id_to_future_lock: + return self.task_id_to_future.get(task_id) + + def get_task_id_to_future(self): + # Testing-only + return self.task_id_to_future + def get_current_runtime_env(self) -> str: # This should never change, so we can safely cache it to avoid ser/de if self.current_runtime_env is None: diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 27a9eaa35e28..ce89c4fbe2d3 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -73,13 +73,19 @@ class TaskCancelledError(RayError): cancelled. """ - def __init__(self, task_id: Optional[TaskID] = None): + def __init__( + self, task_id: Optional[TaskID] = None, error_message: Optional[str] = None + ): self.task_id = task_id + self.error_message = error_message def __str__(self): - if self.task_id is None: - return "This task or its dependency was cancelled by" - return "Task: " + str(self.task_id) + " was cancelled" + msg = "" + if self.task_id: + msg = "Task: " + str(self.task_id) + " was cancelled. 
" + if self.error_message: + msg += self.error_message + return msg @PublicAPI diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 1f52bbea0af0..a6d6a72faee1 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -349,6 +349,11 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_string&, const c_vector[c_string]&) nogil) run_on_util_worker_handler (void(const CRayObject&) nogil) unhandled_exception_handler + (void( + const CTaskID &c_task_id, + const CRayFunction &ray_function, + const c_string c_name_of_concurrency_group_to_execute + ) nogil) cancel_async_task (void(c_string *stack_out) nogil) get_lang_stack c_bool is_local_mode int num_workers diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index ef958d7cc00d..c079af7ed7d7 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -44,6 +44,7 @@ py_test_module_list( "test_basic_4.py", "test_basic_5.py", "test_cancel.py", + "test_actor_cancel.py", "test_dashboard_profiler.py", "test_gcs_fault_tolerance.py", "test_gcs_utils.py", diff --git a/python/ray/tests/test_actor_cancel.py b/python/ray/tests/test_actor_cancel.py new file mode 100644 index 000000000000..d863ff945b8d --- /dev/null +++ b/python/ray/tests/test_actor_cancel.py @@ -0,0 +1,349 @@ +import asyncio +import os +import sys +import time + +import pytest + +import ray +from ray._private.test_utils import SignalActor, wait_for_condition +from ray.exceptions import TaskCancelledError +from ray.util.state import list_tasks + + +def test_input_validation(shutdown_only): + # Verify force=True is not working. + @ray.remote + class A: + async def f(self): + pass + + a = A.remote() + with pytest.raises(TypeError, match="force=True is not supported"): + ray.cancel(a.f.remote(), force=True) + + +def test_async_actor_cancel(shutdown_only): + """ + Test async actor task is canceled and + asyncio.CancelledError is raised within a task. + + If a task is canceled while it is executed, + it should raise RayTaskError. + + TODO(sang): It is awkward we raise RayTaskError + when a task is interrupted. Should we just raise + TaskCancelledError? It is an API change. + """ + ray.init(num_cpus=1) + + @ray.remote + class VerifyActor: + def __init__(self): + self.called = False + + def called(self): + print("called") + self.called = True + + def is_called(self): + print("is caled, ", self.called) + return self.called + + @ray.remote + class Actor: + async def f(self, verify_actor): + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + ray.get(verify_actor.called.remote()) + assert asyncio.get_current_task.canceled() + return True + return False + + v = VerifyActor.remote() + a = Actor.remote() + ref = a.f.remote(v) + ray.get(a.__ray_ready__.remote()) + ray.get(v.__ray_ready__.remote()) + ray.cancel(ref) + + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(ref) + + # Verify asyncio.CancelledError is raised from the actor task. + assert ray.get(v.is_called.remote()) + + +def test_async_actor_client_side_cancel(ray_start_cluster): + """ + Test a task is cancelled while it is queued on a client side. + It should raise ray.exceptions.TaskCancelledError. 
+ """ + cluster = ray_start_cluster + cluster.add_node(num_cpus=0) + ray.init(address=cluster.address) + + @ray.remote(num_cpus=1) + class Actor: + def __init__(self): + self.f_called = False + + async def g(self, ref): + await asyncio.sleep(30) + + async def f(self): + self.f_called = True + await asyncio.sleep(5) + + def is_f_called(self): + return self.f_called + + @ray.remote + def f(): + time.sleep(100) + + # Test the case where a task is queued on a client side. + # Tasks are not sent until actor is created. + a = Actor.remote() + ref = a.f.remote() + ray.cancel(ref) + with pytest.raises(TaskCancelledError): + ray.get(ref) + + cluster.add_node(num_cpus=1) + assert not ray.get(a.is_f_called.remote()) + + # Test the case where it is canceled before dependencies + # are resolved. + a = Actor.remote() + ref_dep_not_resolved = a.g.remote(f.remote()) + ray.cancel(ref_dep_not_resolved) + with pytest.raises(TaskCancelledError): + ray.get(ref_dep_not_resolved) + + +@pytest.mark.skip( + reason=("The guarantee in this case is too weak now. " "Need more work.") +) +def test_in_flight_queued_requests_canceled(shutdown_only, monkeypatch): + """ + When there are large input size in-flight actor tasks + tasks are queued inside a RPC layer (core_worker_client.h) + In this case, we don't cancel a request from a client side + but wait until it is sent to the server side and cancel it. + See SendRequests() inside core_worker_client.h + """ + # Currently the max bytes is + # const int64_t kMaxBytesInFlight = 16 * 1024 * 1024. + # See core_worker_client.h. + input_arg = b"1" * 15 * 1024 # 15KB. + # Tasks are queued when there are more than 1024 tasks. + sig = SignalActor.remote() + + @ray.remote + class Actor: + def __init__(self, signal_actor): + self.signal_actor = signal_actor + + def f(self, input_arg): + ray.get(self.signal_actor.wait.remote()) + return True + + a = Actor.remote(sig) + refs = [a.f.remote(input_arg) for _ in range(5000)] + + # Wait until the first task runs. + wait_for_condition( + lambda: len(list_tasks(filters=[("STATE", "=", "RUNNING")])) == 1 + ) + + # Cancel all tasks. + for ref in refs: + ray.cancel(ref) + + # The first ref is in progress, so we pop it out + first_ref = refs.pop(0) + ray.get(sig.send.remote()) + + # Make sure all tasks that are queued (including queued + # due to in-flight bytes) are canceled. + canceled = 0 + for ref in refs: + try: + ray.get(ref) + except TaskCancelledError: + canceled += 1 + + # Verify at least half of tasks are canceled. + # Currently, the guarantee is weak because we cannot + # detect queued tasks due to inflight bytes limit. + # TODO(sang): Move the in flight bytes logic into + # actor submission queue instead of doing it inside + # core worker client. + assert canceled > 2500 + + # first ref shouldn't have been canceled. + assert ray.get(first_ref) + + +def test_async_actor_server_side_cancel(shutdown_only): + """ + Test Cancelation when a task is queued on a server side. + """ + + @ray.remote + class Actor: + async def f(self): + await asyncio.sleep(5) + + async def g(self): + await asyncio.sleep(0) + + a = Actor.options(max_concurrency=1).remote() + ray.get(a.__ray_ready__.remote()) + ref = a.f.remote() # noqa + # Queued on a server side. + # Task should not be executed at all. 
+ refs = [a.g.remote() for _ in range(100)] + wait_for_condition( + lambda: len( + list_tasks( + filters=[ + ("name", "=", "Actor.g"), + ("STATE", "=", "SUBMITTED_TO_WORKER"), + ] + ) + ) + == 100 + ) + + for ref in refs: + ray.cancel(ref) + tasks = list_tasks(filters=[("name", "=", "Actor.g")]) + + for ref in refs: + with pytest.raises(TaskCancelledError, match=ref.task_id().hex()): + ray.get(ref) + + # Verify the task is submitted to the worker and never executed + # assert task.state == "SUBMITTED_TO_WORKER" + for task in tasks: + assert task.state == "SUBMITTED_TO_WORKER" + + +def test_async_actor_cancel_after_task_finishes(shutdown_only): + @ray.remote + class Actor: + async def f(self): + await asyncio.sleep(5) + + async def empty(self): + pass + + # Cancel after task finishes + a = Actor.options(max_concurrency=1).remote() + ref = a.empty.remote() + ref2 = a.empty.remote() + ray.get([ref, ref2]) + ray.cancel(ref) + ray.cancel(ref2) + # Exceptions shouldn't be raised. + ray.get([ref, ref2]) + + +def test_async_actor_cancel_restart(ray_start_cluster, monkeypatch): + """ + Verify a cancelation works if actor is restarted. + """ + with monkeypatch.context() as m: + # This will slow down the cancelation RPC so that + # cancel won't succeed until a node is killed. + m.setenv( + "RAY_testing_asio_delay_us", + "CoreWorkerService.grpc_server.CancelTask=3000000:3000000", + ) + cluster = ray_start_cluster + cluster.add_node(num_cpus=0) + ray.init(address=cluster.address) + node = cluster.add_node(num_cpus=1) + + @ray.remote(num_cpus=1, max_restarts=-1, max_task_retries=-1) + class Actor: + async def f(self): + await asyncio.sleep(10) + + a = Actor.remote() + ref = a.f.remote() + # This guarantees that a.f.remote() is executed + ray.get(a.__ray_ready__.remote()) + ray.cancel(ref) + cluster.remove_node(node) + r, ur = ray.wait([ref]) + # When cancel is called, the task won't be retried anymore. + # Since an actor is dead, in this case, it will raise + # RayActorError. + with pytest.raises(ray.exceptions.RayActorError): + ray.get(ref) + + # This will restart actor, but task won't be retried. + cluster.add_node(num_cpus=1) + # Verify actor is restarted. f should be retried + ray.get(a.__ray_ready__.remote()) + with pytest.raises(ray.exceptions.RayActorError): + ray.get(ref) + + +def test_remote_cancel(ray_start_regular): + @ray.remote + class Actor: + async def sleep(self): + await asyncio.sleep(1000) + + @ray.remote + def f(refs): + ref = refs[0] + ray.cancel(ref) + + a = Actor.remote() + sleep_ref = a.sleep.remote() + wait_for_condition(lambda: list_tasks(filters=[("name", "=", "Actor.sleep")])) + ref = f.remote([sleep_ref]) # noqa + + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(sleep_ref) + + +@pytest.mark.skip(reason=("Currently not passing. 
There's one edge case to fix."))
+def test_cancel_stress(shutdown_only):
+    ray.init()
+
+    @ray.remote
+    class Actor:
+        async def sleep(self):
+            await asyncio.sleep(1000)
+
+    actors = [Actor.remote() for _ in range(30)]
+
+    refs = []
+    for _ in range(20):
+        for actor in actors:
+            for i in range(100):
+                ref = actor.sleep.remote()
+                refs.append(ref)
+                if i % 2 == 0:
+                    ray.cancel(ref)
+
+    for ref in refs:
+        ray.cancel(ref)
+
+    for ref in refs:
+        with pytest.raises((ray.exceptions.RayTaskError, TaskCancelledError)):
+            ray.get(ref)
+
+
+if __name__ == "__main__":
+    if os.environ.get("PARALLEL_CI"):
+        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
+    else:
+        sys.exit(pytest.main(["-sv", __file__]))
diff --git a/src/mock/ray/core_worker/transport/direct_actor_transport.h b/src/mock/ray/core_worker/transport/direct_actor_transport.h
index 1904a3517ff0..e47de46cd9fd 100644
--- a/src/mock/ray/core_worker/transport/direct_actor_transport.h
+++ b/src/mock/ray/core_worker/transport/direct_actor_transport.h
@@ -57,7 +57,7 @@ class MockSchedulingQueue : public SchedulingQueue {
               (int64_t seq_no,
                int64_t client_processed_up_to,
                std::function<void(rpc::SendReplyCallback)> accept_request,
-               std::function<void(rpc::SendReplyCallback)> reject_request,
+               std::function<void(const Status &, rpc::SendReplyCallback)> reject_request,
                rpc::SendReplyCallback send_reply_callback,
                const std::string &concurrency_group_name,
                const ray::FunctionDescriptor &function_descriptor,
diff --git a/src/ray/common/asio/asio_util.h b/src/ray/common/asio/asio_util.h
index 1bd513c2fe4f..232e397e5ce7 100644
--- a/src/ray/common/asio/asio_util.h
+++ b/src/ray/common/asio/asio_util.h
@@ -17,6 +17,8 @@
 #include
 #include
+#include "ray/common/asio/instrumented_io_context.h"
+
 template std::shared_ptr execute_after(
     instrumented_io_context &io_context,
diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc
index faf7713804e9..01a2735cab3b 100644
--- a/src/ray/common/task/task_spec.cc
+++ b/src/ray/common/task/task_spec.cc
@@ -203,7 +203,10 @@ int TaskSpecification::GetRuntimeEnvHash() const {
 }
 
 const SchedulingClass TaskSpecification::GetSchedulingClass() const {
-  RAY_CHECK(sched_cls_id_ > 0);
+  if (!IsActorTask()) {
+    // Actor tasks don't have a scheduling class id, so the check is skipped for them.
+    RAY_CHECK(sched_cls_id_ > 0);
+  }
   return sched_cls_id_;
 }
 
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index 6971e19124ab..adae1fc335c8 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -1,3 +1,4 @@
+
 // Copyright 2017 The Ray Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -2256,23 +2257,42 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id,
 Status CoreWorker::CancelTask(const ObjectID &object_id,
                               bool force_kill,
                               bool recursive) {
-  if (actor_manager_->CheckActorHandleExists(object_id.TaskId().ActorId())) {
-    return Status::Invalid("Actor task cancellation is not supported.");
-  }
   rpc::Address obj_addr;
   if (!reference_counter_->GetOwner(object_id, &obj_addr)) {
     return Status::Invalid("No owner found for object.");
   }
+  if (obj_addr.SerializeAsString() != rpc_address_.SerializeAsString()) {
+    // There is no separate CancelRemoteTask on direct_actor_submitter_ because
+    // forwarding a cancel request to the object's owner works the same way for
+    // actor tasks and normal tasks.
+ RAY_LOG(DEBUG) << "Request to cancel a task of object id " << object_id + << " to an owner " << obj_addr.SerializeAsString(); return direct_task_submitter_->CancelRemoteTask( object_id, obj_addr, force_kill, recursive); } auto task_spec = task_manager_->GetTaskSpec(object_id.TaskId()); - if (task_spec.has_value() && !task_spec.value().IsActorCreationTask()) { + if (!task_spec.has_value()) { + // Task is already finished. + RAY_LOG(DEBUG) << "Cancel request is ignored because the task is already canceled " + "for an object id " + << object_id; + return Status::OK(); + } + + if (task_spec.value().IsActorCreationTask()) { + RAY_LOG(FATAL) << "Cannot cancel actor creation tasks"; + } + + if (task_spec->IsActorTask()) { + if (force_kill) { + return Status::Invalid("force=True is not supported for actor tasks."); + } + + return direct_actor_submitter_->CancelTask(task_spec.value(), recursive); + } else { return direct_task_submitter_->CancelTask(task_spec.value(), force_kill, recursive); } - return Status::OK(); } Status CoreWorker::CancelChildren(const TaskID &task_id, bool force_kill) { @@ -3531,6 +3551,51 @@ void CoreWorker::HandleCancelTask(rpc::CancelTaskRequest request, rpc::CancelTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { TaskID task_id = TaskID::FromBinary(request.intended_task_id()); + bool force_kill = request.force_kill(); + bool recursive = request.recursive(); + const auto ¤t_actor_id = worker_context_.GetCurrentActorID(); + const auto caller_worker_id = WorkerID::FromBinary(request.caller_worker_id()); + + auto on_cancel_callback = [this, + reply, + send_reply_callback = std::move(send_reply_callback), + force_kill, + task_id](bool success, bool requested_task_running) { + reply->set_attempt_succeeded(success); + reply->set_requested_task_running(requested_task_running); + send_reply_callback(Status::OK(), nullptr, nullptr); + + // Do force kill after reply callback sent. + if (force_kill) { + // We grab the lock again to make sure that we are force-killing the correct + // task. This is guaranteed not to deadlock because ForceExit should not + // require any other locks. + absl::MutexLock lock(&mutex_); + if (main_thread_task_id_ == task_id) { + ForceExit(rpc::WorkerExitType::INTENDED_USER_EXIT, + absl::StrCat("The worker exits because the task ", + main_thread_task_name_, + " has received a force ray.cancel request.")); + } + } + }; + + if (task_id.ActorId() == current_actor_id) { + RAY_LOG(INFO) << "Cancel an actor task " << task_id << " for an actor " + << current_actor_id; + CancelActorTaskOnExecutor( + caller_worker_id, task_id, force_kill, recursive, on_cancel_callback); + } else { + RAY_CHECK(current_actor_id.IsNil()); + RAY_LOG(INFO) << "Cancel a normal task " << task_id; + CancelTaskOnExecutor(task_id, force_kill, recursive, on_cancel_callback); + } +} + +void CoreWorker::CancelTaskOnExecutor(TaskID task_id, + bool force_kill, + bool recursive, + OnCanceledCallback on_canceled) { bool requested_task_running; { absl::MutexLock lock(&mutex_); @@ -3545,7 +3610,7 @@ void CoreWorker::HandleCancelTask(rpc::CancelTaskRequest request, // the kill callback runs; the kill callback is responsible for also making // sure it cancels the right task. // See https://github.com/ray-project/ray/issues/29739. 
- if (requested_task_running && !request.force_kill()) { + if (requested_task_running && !force_kill) { RAY_LOG(INFO) << "Cancelling a running task with id: " << task_id; success = options_.kill_main(task_id); } else if (!requested_task_running) { @@ -3555,29 +3620,84 @@ void CoreWorker::HandleCancelTask(rpc::CancelTaskRequest request, // normal tasks, and remove it if found. success = direct_task_receiver_->CancelQueuedNormalTask(task_id); } - if (request.recursive()) { - auto recursive_cancel = CancelChildren(task_id, request.force_kill()); + if (recursive) { + auto recursive_cancel = CancelChildren(task_id, force_kill); if (!recursive_cancel.ok()) { RAY_LOG(ERROR) << recursive_cancel.ToString(); } } - reply->set_attempt_succeeded(success); - reply->set_requested_task_running(requested_task_running); - send_reply_callback(Status::OK(), nullptr, nullptr); + on_canceled(/*success*/ success, /*requested_task_running*/ requested_task_running); +} - // Do force kill after reply callback sent. - if (request.force_kill()) { - // We grab the lock again to make sure that we are force-killing the correct - // task. This is guaranteed not to deadlock because ForceExit should not - // require any other locks. - absl::MutexLock lock(&mutex_); - if (main_thread_task_id_ == task_id) { - ForceExit(rpc::WorkerExitType::INTENDED_USER_EXIT, - absl::StrCat("The worker exits because the task ", - main_thread_task_name_, - " has received a force ray.cancel request.")); +void CoreWorker::CancelActorTaskOnExecutor(WorkerID caller_worker_id, + TaskID task_id, + bool force_kill, + bool recursive, + OnCanceledCallback on_canceled) { + RAY_CHECK(!force_kill); + auto is_async_actor = worker_context_.CurrentActorIsAsync(); + + auto cancel = [this, + task_id, + caller_worker_id, + on_canceled = std::move(on_canceled), + is_async_actor]() { + bool is_task_running; + TaskSpecification spec; + RayFunction func; + std::string concurrency_group_name; + + bool is_task_queued_or_executing = + direct_task_receiver_->CancelQueuedActorTask(caller_worker_id, task_id); + + // If a task is already running, we send a cancel request. + // Right now, we can only cancel async actor tasks. + if (is_task_queued_or_executing) { + { + absl::MutexLock lock(&mutex_); + auto it = current_tasks_.find(task_id); + is_task_running = it != current_tasks_.end(); + if (is_task_running) { + spec = it->second; + func = RayFunction(spec.GetLanguage(), spec.FunctionDescriptor()); + concurrency_group_name = spec.ConcurrencyGroupName(); + } + } + + if (is_task_running && is_async_actor) { + options_.cancel_async_task(task_id, func, concurrency_group_name); + } + // TODO(sang): else support regular actor interrupt. } + + // If `is_task_queued_or_executing`is true, task was either queued or run. + // If a task is queued, it is guaranteed to be canceled by + // CancelQueuedActorTask. If a task is executing, we try canceling + // them, but it is not guaranteed. For both cases, we consider cancelation + // succeeds. If `is_task_queued_or_executing` is false, it means task is finished + // or not received yet. In this case, we mark `success` as false, so that the + // caller can retry cancel RPCs. Note that the caller knows exactly when a task is + // finished from their end, so it won't infinitely retry cancel RPCs. + // requested_task_running is not used, so we just always mark it as false. 
+ on_canceled(/*success*/ is_task_queued_or_executing, + /*requested_task_running*/ false); + }; + + if (is_async_actor) { + // If it is an async actor, post it to an execution service + // to avoid thread issues. Note that when it is an async actor + // task_execution_service_ won't actually run a task but it will + // just create coroutines. + task_execution_service_.post([cancel = std::move(cancel)]() { cancel(); }, + "CoreWorker.CancelActorTaskOnExecutor"); + } else { + // For regular actor, we cannot post it to task_execution_service because + // main thread is blocked. Threaded actor can do both (dispatching to + // task execution service, or just directly call it in io_service). + // There's no special reason why we don't dispatch + // cancel to task_execution_service_ for threaded actors. + cancel(); } } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 16e24c68914b..4cb8f9eebd67 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1219,6 +1219,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { using SetResultCallback = std::function, ObjectID object_id, void *)>; + using OnCanceledCallback = std::function; + /// Perform async get from the object store. /// /// \param[in] object_id The id to call get on. @@ -1549,6 +1551,31 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { Status WaitForActorRegistered(const std::vector &ids); + /// Cancel a normal task (non-actor-task) queued or running in the current worker. + /// + /// \param intended_task_id The ID of a task to cancel. + /// \param force_kill If true, kill the worker. + /// \param recursive If true, cancel all children tasks of the intended_task_id. + /// \param on_canceled Callback called after a task is canceled. + /// It has two inputs, which corresponds to requested_task_running (if task is still + /// running after a cancelation attempt is done) and attempt_succeeded (if task + /// is canceled, and a caller doesn't have to retry). + void CancelTaskOnExecutor(TaskID intended_task_id, + bool force_kill, + bool recursive, + OnCanceledCallback on_canceled); + + /// Cancel an actor task queued or running in the current worker. + /// + /// See params in CancelTaskOnExecutor. + /// For the actor task cancel protocol, see the docstring of + /// direct_actor_task_submitter.h::CancelTask. + void CancelActorTaskOnExecutor(WorkerID caller_worker_id, + TaskID intended_task_id, + bool force_kill, + bool recursive, + OnCanceledCallback on_canceled); + /// Shared state of the worker. Includes process-level and thread-level state. /// TODO(edoakes): we should move process-level state into this class and make /// this a ThreadContext. diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index 05623bb25d36..e5825c29f61f 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -86,6 +86,7 @@ struct CoreWorkerOptions { unhandled_exception_handler(nullptr), get_lang_stack(nullptr), kill_main(nullptr), + cancel_async_task(nullptr), is_local_mode(false), terminate_asyncio_thread(nullptr), serialized_job_config(""), @@ -162,6 +163,10 @@ struct CoreWorkerOptions { // Function that tries to interrupt the currently running Python thread if its // task ID matches the one given. std::function kill_main; + std::function + cancel_async_task; /// Is local mode being used. bool is_local_mode; /// The function to destroy asyncio event and loops. 
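The cancel_async_task hook wired up above ultimately calls future.cancel() on the asyncio future that run_async_func_or_coro_in_event_loop registered under the task ID (see the _raylet.pyx hunks earlier in this patch). Below is a standalone sketch of that mechanism in plain asyncio, using simplified, hypothetical helper names and no Ray dependency; it mirrors the pattern, not the actual code paths.

import asyncio
import concurrent.futures
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

task_id_to_future = {}
lock = threading.Lock()


async def actor_method():
    await asyncio.sleep(1000)


def submit(task_id):
    # Mirrors run_async_func_or_coro_in_event_loop: queue the coroutine on the
    # actor's event loop and remember an asyncio-level handle under the task ID.
    fut = asyncio.run_coroutine_threadsafe(actor_method(), loop)
    with lock:
        task_id_to_future[task_id] = asyncio.wrap_future(fut, loop=loop)
    return fut


def cancel(task_id):
    # Mirrors cancel_async_task: cancel the tracked future on the loop's thread.
    with lock:
        wrapped = task_id_to_future.get(task_id)
    if wrapped is not None:
        loop.call_soon_threadsafe(wrapped.cancel)


fut = submit("task-1")
cancel("task-1")
try:
    fut.result()
except concurrent.futures.CancelledError:
    # _raylet.pyx translates this into TaskCancelledError(task_id).
    print("coroutine was canceled")
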
diff --git a/src/ray/core_worker/test/scheduling_queue_test.cc b/src/ray/core_worker/test/scheduling_queue_test.cc index e96ccc2bb374..910cfea593f6 100644 --- a/src/ray/core_worker/test/scheduling_queue_test.cc +++ b/src/ray/core_worker/test/scheduling_queue_test.cc @@ -38,7 +38,7 @@ class MockActorSchedulingQueue { void Add(int64_t seq_no, int64_t client_processed_up_to, std::function accept_request, - std::function reject_request, + std::function reject_request, rpc::SendReplyCallback send_reply_callback = nullptr, TaskID task_id = TaskID::Nil(), const std::vector &dependencies = {}) { @@ -81,7 +81,9 @@ TEST(SchedulingQueueTest, TestInOrder) { int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; - auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + auto fn_rej = [&n_rej](const Status &status, rpc::SendReplyCallback callback) { + n_rej++; + }; queue.Add(0, -1, fn_ok, fn_rej, nullptr); queue.Add(1, -1, fn_ok, fn_rej, nullptr); queue.Add(2, -1, fn_ok, fn_rej, nullptr); @@ -102,7 +104,9 @@ TEST(SchedulingQueueTest, TestWaitForObjects) { int n_rej = 0; auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; - auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + auto fn_rej = [&n_rej](const Status &status, rpc::SendReplyCallback callback) { + n_rej++; + }; queue.Add(0, -1, fn_ok, fn_rej, nullptr); queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1})); queue.Add(2, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj2})); @@ -129,7 +133,9 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { int n_rej = 0; auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; - auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + auto fn_rej = [&n_rej](const Status &status, rpc::SendReplyCallback callback) { + n_rej++; + }; queue.Add(0, -1, fn_ok, fn_rej, nullptr); queue.Add(1, -1, fn_ok, fn_rej, nullptr, TaskID::Nil(), ObjectIdsToRefs({obj1})); @@ -147,7 +153,9 @@ TEST(SchedulingQueueTest, TestOutOfOrder) { int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; - auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + auto fn_rej = [&n_rej](const Status &status, rpc::SendReplyCallback callback) { + n_rej++; + }; queue.Add(2, -1, fn_ok, fn_rej, nullptr); queue.Add(0, -1, fn_ok, fn_rej, nullptr); queue.Add(3, -1, fn_ok, fn_rej, nullptr); @@ -164,7 +172,9 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) { int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; - auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + auto fn_rej = [&n_rej](const Status &status, rpc::SendReplyCallback callback) { + n_rej++; + }; queue.Add(2, -1, fn_ok, fn_rej, nullptr); queue.Add(0, -1, fn_ok, fn_rej, nullptr); queue.Add(3, -1, fn_ok, fn_rej, nullptr); @@ -186,7 +196,9 @@ TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) { int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; - auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + auto fn_rej = [&n_rej](const Status &status, rpc::SendReplyCallback callback) { + n_rej++; + }; queue.Add(2, 2, fn_ok, fn_rej, nullptr); queue.Add(3, 2, fn_ok, fn_rej, nullptr); queue.Add(1, 2, fn_ok, fn_rej, nullptr); @@ -201,7 +213,9 @@ TEST(SchedulingQueueTest, TestCancelQueuedTask) { int n_ok = 0; int n_rej = 0; auto fn_ok = 
[&n_ok](rpc::SendReplyCallback callback) { n_ok++; }; - auto fn_rej = [&n_rej](rpc::SendReplyCallback callback) { n_rej++; }; + auto fn_rej = [&n_rej](const Status &status, rpc::SendReplyCallback callback) { + n_rej++; + }; queue->Add(-1, -1, fn_ok, fn_rej, nullptr, "", FunctionDescriptorBuilder::Empty()); queue->Add(-1, -1, fn_ok, fn_rej, nullptr, "", FunctionDescriptorBuilder::Empty()); queue->Add(-1, -1, fn_ok, fn_rej, nullptr, "", FunctionDescriptorBuilder::Empty()); diff --git a/src/ray/core_worker/transport/actor_scheduling_queue.cc b/src/ray/core_worker/transport/actor_scheduling_queue.cc index 31a4d8711ac6..a85387db2814 100644 --- a/src/ray/core_worker/transport/actor_scheduling_queue.cc +++ b/src/ray/core_worker/transport/actor_scheduling_queue.cc @@ -68,15 +68,16 @@ size_t ActorSchedulingQueue::Size() const { } /// Add a new actor task's callbacks to the worker queue. -void ActorSchedulingQueue::Add(int64_t seq_no, - int64_t client_processed_up_to, - std::function accept_request, - std::function reject_request, - rpc::SendReplyCallback send_reply_callback, - const std::string &concurrency_group_name, - const ray::FunctionDescriptor &function_descriptor, - TaskID task_id, - const std::vector &dependencies) { +void ActorSchedulingQueue::Add( + int64_t seq_no, + int64_t client_processed_up_to, + std::function accept_request, + std::function reject_request, + rpc::SendReplyCallback send_reply_callback, + const std::string &concurrency_group_name, + const ray::FunctionDescriptor &function_descriptor, + TaskID task_id, + const std::vector &dependencies) { // A seq_no of -1 means no ordering constraint. Actor tasks must be executed in order. RAY_CHECK(seq_no != -1); @@ -95,6 +96,10 @@ void ActorSchedulingQueue::Add(int64_t seq_no, dependencies.size() > 0, concurrency_group_name, function_descriptor); + { + absl::MutexLock lock(&mu_); + pending_task_id_to_is_canceled.emplace(task_id, false); + } if (dependencies.size() > 0) { waiter_.Wait(dependencies, [seq_no, this]() { @@ -109,13 +114,16 @@ void ActorSchedulingQueue::Add(int64_t seq_no, ScheduleRequests(); } -// We don't allow the cancellation of actor tasks, so invoking CancelTaskIfFound -// results in a fatal error. bool ActorSchedulingQueue::CancelTaskIfFound(TaskID task_id) { - RAY_CHECK(false) << "Cannot cancel actor tasks"; - // The return instruction will never be executed, but we need to include it - // nonetheless because this is a non-void function. - return false; + absl::MutexLock lock(&mu_); + if (pending_task_id_to_is_canceled.find(task_id) != + pending_task_id_to_is_canceled.end()) { + // Mark the task is canceled. + pending_task_id_to_is_canceled[task_id] = true; + return true; + } else { + return false; + } } /// Schedules as many requests as possible in sequence. @@ -126,7 +134,11 @@ void ActorSchedulingQueue::ScheduleRequests() { auto head = pending_actor_tasks_.begin(); RAY_LOG(ERROR) << "Cancelling stale RPC with seqno " << pending_actor_tasks_.begin()->first << " < " << next_seq_no_; - head->second.Cancel(); + head->second.Cancel(Status::Invalid("client cancelled stale rpc")); + { + absl::MutexLock lock(&mu_); + pending_task_id_to_is_canceled.erase(head->second.TaskID()); + } pending_actor_tasks_.erase(head); } @@ -136,21 +148,26 @@ void ActorSchedulingQueue::ScheduleRequests() { pending_actor_tasks_.begin()->second.CanExecute()) { auto head = pending_actor_tasks_.begin(); auto request = head->second; + auto task_id = head->second.TaskID(); if (is_asyncio_) { // Process async actor task. 
auto fiber = fiber_state_manager_->GetExecutor(request.ConcurrencyGroupName(), request.FunctionDescriptor()); - fiber->EnqueueFiber([request]() mutable { request.Accept(); }); + fiber->EnqueueFiber([this, request, task_id]() mutable { + AcceptRequestOrRejectIfCanceled(task_id, request); + }); } else { // Process actor tasks. RAY_CHECK(pool_manager_ != nullptr); auto pool = pool_manager_->GetExecutor(request.ConcurrencyGroupName(), request.FunctionDescriptor()); if (pool == nullptr) { - request.Accept(); + AcceptRequestOrRejectIfCanceled(task_id, request); } else { - pool->Post([request]() mutable { request.Accept(); }); + pool->Post([this, request, task_id]() mutable { + AcceptRequestOrRejectIfCanceled(task_id, request); + }); } } pending_actor_tasks_.erase(head); @@ -182,11 +199,38 @@ void ActorSchedulingQueue::OnSequencingWaitTimeout() { << ", cancelling all queued tasks"; while (!pending_actor_tasks_.empty()) { auto head = pending_actor_tasks_.begin(); - head->second.Cancel(); + head->second.Cancel(Status::Invalid("client cancelled stale rpc")); next_seq_no_ = std::max(next_seq_no_, head->first + 1); + { + absl::MutexLock lock(&mu_); + pending_task_id_to_is_canceled.erase(head->second.TaskID()); + } pending_actor_tasks_.erase(head); } } +void ActorSchedulingQueue::AcceptRequestOrRejectIfCanceled(TaskID task_id, + InboundRequest &request) { + bool is_canceled = false; + { + absl::MutexLock lock(&mu_); + auto it = pending_task_id_to_is_canceled.find(task_id); + if (it != pending_task_id_to_is_canceled.end()) { + is_canceled = it->second; + } + } + + // Accept can be very long, and we shouldn't hold a lock. + if (is_canceled) { + request.Cancel( + Status::SchedulingCancelled("Task is canceled before it is scheduled.")); + } else { + request.Accept(); + } + + absl::MutexLock lock(&mu_); + pending_task_id_to_is_canceled.erase(task_id); +} + } // namespace core } // namespace ray diff --git a/src/ray/core_worker/transport/actor_scheduling_queue.h b/src/ray/core_worker/transport/actor_scheduling_queue.h index 126577ba9a0d..f94266bed4dc 100644 --- a/src/ray/core_worker/transport/actor_scheduling_queue.h +++ b/src/ray/core_worker/transport/actor_scheduling_queue.h @@ -59,21 +59,27 @@ class ActorSchedulingQueue : public SchedulingQueue { void Add(int64_t seq_no, int64_t client_processed_up_to, std::function accept_request, - std::function reject_request, + std::function reject_request, rpc::SendReplyCallback send_reply_callback, const std::string &concurrency_group_name, const ray::FunctionDescriptor &function_descriptor, TaskID task_id = TaskID::Nil(), const std::vector &dependencies = {}) override; - // We don't allow the cancellation of actor tasks, so invoking CancelTaskIfFound - // results in a fatal error. + /// Cancel the actor task in the queue. + /// Tasks are in the queue if it is either queued, or executing. + /// Return true if a task is in the queue. False otherwise. + /// This method has to be THREAD-SAFE. bool CancelTaskIfFound(TaskID task_id) override; /// Schedules as many requests as possible in sequence. void ScheduleRequests() override; private: + /// Accept the given InboundRequest or reject it if a task id is canceled via + /// CancelTaskIfFound. + void AcceptRequestOrRejectIfCanceled(TaskID task_id, InboundRequest &request); + /// Called when we time out waiting for an earlier task to show up. void OnSequencingWaitTimeout(); /// Max time in seconds to wait for dependencies to show up. 
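The queue-side protocol here is small: CancelTaskIfFound only flips a per-task flag under a lock, and the flag is checked immediately before the queued request would run, so cancellation never races with execution. A hypothetical, single-file Python sketch of that pattern (not Ray code) follows.

import threading


class ToyCancelableQueue:
    def __init__(self):
        self._lock = threading.Lock()
        self._canceled = {}  # task_id -> bool, only for pending (queued/running) tasks

    def add(self, task_id):
        with self._lock:
            self._canceled[task_id] = False

    def cancel_task_if_found(self, task_id):
        # Thread-safe: may be called from the RPC thread while tasks execute.
        with self._lock:
            if task_id in self._canceled:
                self._canceled[task_id] = True
                return True
        return False  # Already finished (or never received); the caller may retry.

    def accept_or_reject_if_canceled(self, task_id, accept, reject):
        with self._lock:
            canceled = self._canceled.get(task_id, False)
        # Accept can take a long time, so the lock is not held while running it.
        if canceled:
            reject("Task is canceled before it is scheduled.")
        else:
            accept()
        with self._lock:
            self._canceled.pop(task_id, None)


q = ToyCancelableQueue()
q.add("t1")
q.cancel_task_if_found("t1")
q.accept_or_reject_if_canceled("t1", accept=lambda: print("run"), reject=print)
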
@@ -97,6 +103,11 @@ class ActorSchedulingQueue : public SchedulingQueue { /// Whether we should enqueue requests into asyncio pool. Setting this to true /// will instantiate all tasks as fibers that can be yielded. bool is_asyncio_ = false; + /// Mutext to protect attributes used for thread safe APIs. + absl::Mutex mu_; + /// A map of actor task IDs -> is_canceled + /// Pending means tasks are queued or running. + absl::flat_hash_map pending_task_id_to_is_canceled GUARDED_BY(mu_); friend class SchedulingQueueTest; }; diff --git a/src/ray/core_worker/transport/actor_scheduling_util.cc b/src/ray/core_worker/transport/actor_scheduling_util.cc index a5117b9a4714..217214f5d688 100644 --- a/src/ray/core_worker/transport/actor_scheduling_util.cc +++ b/src/ray/core_worker/transport/actor_scheduling_util.cc @@ -21,7 +21,7 @@ InboundRequest::InboundRequest() {} InboundRequest::InboundRequest( std::function accept_callback, - std::function reject_callback, + std::function reject_callback, rpc::SendReplyCallback send_reply_callback, class TaskID task_id, bool has_dependencies, @@ -36,7 +36,9 @@ InboundRequest::InboundRequest( has_pending_dependencies_(has_dependencies) {} void InboundRequest::Accept() { accept_callback_(std::move(send_reply_callback_)); } -void InboundRequest::Cancel() { reject_callback_(std::move(send_reply_callback_)); } +void InboundRequest::Cancel(const Status &status) { + reject_callback_(status, std::move(send_reply_callback_)); +} bool InboundRequest::CanExecute() const { return !has_pending_dependencies_; } ray::TaskID InboundRequest::TaskID() const { return task_id_; } diff --git a/src/ray/core_worker/transport/actor_scheduling_util.h b/src/ray/core_worker/transport/actor_scheduling_util.h index 75728cbce542..bfa7cad2c08d 100644 --- a/src/ray/core_worker/transport/actor_scheduling_util.h +++ b/src/ray/core_worker/transport/actor_scheduling_util.h @@ -27,16 +27,17 @@ namespace core { class InboundRequest { public: InboundRequest(); - InboundRequest(std::function accept_callback, - std::function reject_callback, - rpc::SendReplyCallback send_reply_callback, - TaskID task_id, - bool has_dependencies, - const std::string &concurrency_group_name, - const ray::FunctionDescriptor &function_descriptor); + InboundRequest( + std::function accept_callback, + std::function reject_callback, + rpc::SendReplyCallback send_reply_callback, + TaskID task_id, + bool has_dependencies, + const std::string &concurrency_group_name, + const ray::FunctionDescriptor &function_descriptor); void Accept(); - void Cancel(); + void Cancel(const Status &status); bool CanExecute() const; ray::TaskID TaskID() const; const std::string &ConcurrencyGroupName() const; @@ -45,7 +46,7 @@ class InboundRequest { private: std::function accept_callback_; - std::function reject_callback_; + std::function reject_callback_; rpc::SendReplyCallback send_reply_callback_; ray::TaskID task_id_; diff --git a/src/ray/core_worker/transport/actor_submit_queue.h b/src/ray/core_worker/transport/actor_submit_queue.h index 85a4981ef791..e59af54897f0 100644 --- a/src/ray/core_worker/transport/actor_submit_queue.h +++ b/src/ray/core_worker/transport/actor_submit_queue.h @@ -58,6 +58,10 @@ class IActorSubmitQueue { virtual void MarkDependencyFailed(uint64_t sequence_no) = 0; /// Mark a task's dependency is resolved thus ready to send. virtual void MarkDependencyResolved(uint64_t sequence_no) = 0; + // Mark a task has been canceled. 
+ // If a task hasn't been sent yet, this API will guarantee a task won't be + // popped via PopNextTaskToSend. + virtual void MarkTaskCanceled(uint64_t sequence_no) = 0; /// Clear the queue and returns all tasks ids that haven't been sent yet. virtual std::vector ClearAllTasks() = 0; /// Find next task to send. diff --git a/src/ray/core_worker/transport/direct_actor_task_submitter.cc b/src/ray/core_worker/transport/direct_actor_task_submitter.cc index 0451e5c0ae1a..cc9345d732a4 100644 --- a/src/ray/core_worker/transport/direct_actor_task_submitter.cc +++ b/src/ray/core_worker/transport/direct_actor_task_submitter.cc @@ -377,6 +377,7 @@ void CoreWorkerDirectActorTaskSubmitter::SendPendingTasks(const ActorID &actor_i if (!task.has_value()) { break; } + io_service_.post( [this, task_spec = std::move(task.value().first)] { rpc::PushTaskReply reply; @@ -435,6 +436,8 @@ void CoreWorkerDirectActorTaskSubmitter::ResendOutOfOrderTasks(const ActorID &ac void CoreWorkerDirectActorTaskSubmitter::PushActorTask(ClientQueue &queue, const TaskSpecification &task_spec, bool skip_queue) { + const auto task_id = task_spec.TaskId(); + auto request = std::make_unique(); // NOTE(swang): CopyFrom is needed because if we use Swap here and the task // fails, then the task data will be gone when the TaskManager attempts to @@ -444,7 +447,6 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(ClientQueue &queue, request->set_intended_worker_id(queue.worker_id); request->set_sequence_number(queue.actor_submit_queue->GetSequenceNumber(task_spec)); - const auto task_id = task_spec.TaskId(); const auto actor_id = task_spec.ActorId(); const auto actor_counter = task_spec.ActorCounter(); const auto num_queued = queue.inflight_task_callbacks.size(); @@ -510,6 +512,19 @@ void CoreWorkerDirectActorTaskSubmitter::HandlePushTaskReply( } else if (status.ok()) { task_finisher_.CompletePendingTask( task_id, reply, addr, reply.is_application_error()); + } else if (status.IsSchedulingCancelled()) { + std::ostringstream stream; + stream << "The task " << task_id << " is canceled from an actor " << actor_id + << " before it executes."; + const auto &msg = stream.str(); + RAY_LOG(DEBUG) << msg; + rpc::RayErrorInfo error_info; + error_info.set_error_message(msg); + error_info.set_error_type(rpc::ErrorType::TASK_CANCELLED); + GetTaskFinisherWithoutMu().FailPendingTask(task_spec.TaskId(), + rpc::ErrorType::TASK_CANCELLED, + /*status*/ nullptr, + &error_info); } else { bool is_actor_dead = false; bool fail_immediatedly = false; @@ -623,5 +638,138 @@ std::string CoreWorkerDirectActorTaskSubmitter::DebugString( return stream.str(); } +void CoreWorkerDirectActorTaskSubmitter::RetryCancelTask(TaskSpecification task_spec, + bool recursive, + int64_t milliseconds) { + RAY_LOG(DEBUG) << "Task " << task_spec.TaskId() << " cancelation will be retried in " + << milliseconds << " ms"; + execute_after( + io_service_, + [this, task_spec = std::move(task_spec), recursive] { + RAY_UNUSED(CancelTask(task_spec, recursive)); + }, + std::chrono::milliseconds(milliseconds)); +} + +Status CoreWorkerDirectActorTaskSubmitter::CancelTask(TaskSpecification task_spec, + bool recursive) { + // We don't support force_kill = true for actor tasks. + bool force_kill = false; + RAY_LOG(INFO) << "Cancelling a task: " << task_spec.TaskId() + << " for an actor: " << task_spec.ActorId() + << " force_kill: " << force_kill << " recursive: " << recursive; + + // Tasks are in one of the following states. 
+ // - dependencies not resolved + // - queued + // - sent + // - finished. + + const auto actor_id = task_spec.ActorId(); + const auto &task_id = task_spec.TaskId(); + auto send_pos = task_spec.ActorCounter(); + + // Shouldn't hold a lock while accessing task_finisher_. + // Task is already canceled or finished. + if (!GetTaskFinisherWithoutMu().MarkTaskCanceled(task_id)) { + RAY_LOG(DEBUG) << "a task " << task_id << " is already finished or canceled"; + return Status::OK(); + } + + auto task_queued = false; + { + absl::MutexLock lock(&mu_); + + auto queue = client_queues_.find(actor_id); + RAY_CHECK(queue != client_queues_.end()); + if (queue->second.state == rpc::ActorTableData::DEAD) { + // No need to decrement cur_pending_calls because it doesn't matter. + RAY_LOG(DEBUG) << "a task " << task_id + << "'s actor is already dead. Ignoring the cancel request."; + return Status::OK(); + } + + task_queued = queue->second.actor_submit_queue->Contains(send_pos); + if (task_queued) { + auto dep_resolved = queue->second.actor_submit_queue->Get(send_pos).second; + if (!dep_resolved) { + RAY_LOG(DEBUG) + << "a task " << task_id + << " has been resolving dependencies. Cancel to resolve dependencies"; + resolver_.CancelDependencyResolution(task_id); + } + RAY_LOG(DEBUG) << "a task " << task_id + << " was queued. Mark a task is canceled from a queue."; + queue->second.actor_submit_queue->MarkTaskCanceled(send_pos); + } + } + + // Fail a request immediately if it is still queued. + // The task won't be sent to an actor in this case. + // We cannot hold a lock when calling `FailOrRetryPendingTask`. + if (task_queued) { + rpc::RayErrorInfo error_info; + std::ostringstream stream; + stream << "The task " << task_id << " is canceled from an actor " << actor_id + << " before it executes."; + error_info.set_error_message(stream.str()); + error_info.set_error_type(rpc::ErrorType::TASK_CANCELLED); + GetTaskFinisherWithoutMu().FailOrRetryPendingTask( + task_id, rpc::ErrorType::TASK_CANCELLED, /*status*/ nullptr, &error_info); + return Status::OK(); + } + + // At this point, the task is in "sent" state and not finished yet. + // We cannot guarantee a cancel request is received "after" a task + // is submitted because gRPC is not ordered. To get around it, + // we keep retrying cancel RPCs until task is finished or + // an executor tells us to stop retrying. + + // If there's no client, it means actor is not created yet. + // Retry in 1 second. + { + absl::MutexLock lock(&mu_); + RAY_LOG(DEBUG) << "a task " << task_id << " was sent to an actor. Send a cancel RPC."; + auto queue = client_queues_.find(actor_id); + RAY_CHECK(queue != client_queues_.end()); + if (!queue->second.rpc_client) { + RetryCancelTask(task_spec, recursive, 1000); + return Status::OK(); + } + + const auto &client = queue->second.rpc_client; + auto request = rpc::CancelTaskRequest(); + request.set_intended_task_id(task_spec.TaskId().Binary()); + request.set_force_kill(force_kill); + request.set_recursive(recursive); + request.set_caller_worker_id(task_spec.CallerWorkerId().Binary()); + client->CancelTask(request, + [this, task_spec, recursive, task_id]( + const Status &status, const rpc::CancelTaskReply &reply) { + RAY_LOG(DEBUG) << "CancelTask RPC response received for " + << task_spec.TaskId() << " with status " + << status.ToString(); + + // Keep retrying every 2 seconds until a task is officially + // finished. + if (!GetTaskFinisherWithoutMu().GetTaskSpec(task_id)) { + // Task is already finished. 
+                       RAY_LOG(DEBUG) << "Task " << task_spec.TaskId()
+                                      << " is finished. Stop a cancel request.";
+                       return;
+                     }
+
+                     if (!reply.attempt_succeeded()) {
+                       RetryCancelTask(task_spec, recursive, 2000);
+                     }
+                   });
+  }
+
+  // NOTE: Currently, ray.cancel is asynchronous.
+  // If we want a stronger guarantee on the cancellation result,
+  // we should make it synchronous, but that could regress performance.
+  return Status::OK();
+}
+
 }  // namespace core
 }  // namespace ray
diff --git a/src/ray/core_worker/transport/direct_actor_task_submitter.h b/src/ray/core_worker/transport/direct_actor_task_submitter.h
index 457294ecc6ae..48bcd7fb6c80 100644
--- a/src/ray/core_worker/transport/direct_actor_task_submitter.h
+++ b/src/ray/core_worker/transport/direct_actor_task_submitter.h
@@ -25,6 +25,7 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/synchronization/mutex.h"
+#include "ray/common/asio/asio_util.h"
 #include "ray/common/id.h"
 #include "ray/common/ray_object.h"
 #include "ray/core_worker/actor_creator.h"
@@ -177,6 +178,56 @@ class CoreWorkerDirectActorTaskSubmitter
   /// \return Whether this actor is alive.
   bool IsActorAlive(const ActorID &actor_id) const;
 
+  /// Cancel an actor task of a given task spec.
+  ///
+  /// Asynchronous API.
+  /// The API is thread-safe.
+  ///
+  /// The cancellation protocol requires coordination between
+  /// the caller side and the executor side.
+  ///
+  /// Once the task is canceled, the task's retry count becomes 0.
+  ///
+  /// The client side protocol is as follows:
+  ///
+  /// - Dependencies not resolved
+  ///   - Cancel dependency resolution and fail the object immediately.
+  /// - Dependencies resolved and task queued
+  ///   - Remove the entry from the queue and fail the object immediately.
+  /// - Task sent to the executor
+  ///   - Keep retrying cancel RPCs until the executor reports success
+  ///     (the task was queued or executing) or the task is finished.
+  /// - Task finished
+  ///   - Do nothing if cancel is requested here.
+  ///
+  /// The executor side protocol is as follows:
+  ///
+  /// - Task not received
+  ///   - Fail the cancel RPC. The client will retry.
+  /// - Task queued
+  ///   - Register the canceled task and fail it when it would otherwise be
+  ///     executed.
+  /// - Task executing
+  ///   - If it is an async task, trigger future.cancel. Otherwise, do nothing.
+  ///     TODO(sang): We should ideally update the runtime context so that
+  ///     users can do cooperative cancellation.
+  /// - Task finished
+  ///   - Just fail the cancel RPC. We cannot distinguish this from the
+  ///     "Task not received" state because we don't track all finished
+  ///     tasks. We rely on the client side to stop retrying RPCs
+  ///     when the task finishes.
+  ///
+  /// \param task_spec The task spec of a task that will be canceled.
+  /// \param recursive If true, it will cancel all child tasks.
+  /// \return OK if the cancel request is unnecessary or has been issued.
+  /// Note that the task might still not be canceled even when OK is
+  /// returned, because this is an asynchronous API.
+  Status CancelTask(TaskSpecification task_spec, bool recursive);
+
+  /// Retry CancelTask after the given delay in milliseconds.
+ void RetryCancelTask(TaskSpecification task_spec, bool recursive, int64_t milliseconds); + private: /// A helper function to get task finisher without holding mu_ /// We should use this function when access diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 2dead83132fa..0a1e61dff9e8 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -207,16 +207,16 @@ void CoreWorkerDirectTaskReceiver::HandleTask( } }; - auto cancel_callback = [reply, task_spec](rpc::SendReplyCallback send_reply_callback) { + auto cancel_callback = [reply, task_spec](const Status &status, + rpc::SendReplyCallback send_reply_callback) { if (task_spec.IsActorTask()) { // We consider cancellation of actor tasks to be a push task RPC failure. - send_reply_callback( - Status::Invalid("client cancelled stale rpc"), nullptr, nullptr); + send_reply_callback(status, nullptr, nullptr); } else { // We consider cancellation of normal tasks to be an in-band cancellation of a // successful RPC. reply->set_was_cancelled_before_running(true); - send_reply_callback(Status::OK(), nullptr, nullptr); + send_reply_callback(status, nullptr, nullptr); } }; @@ -289,6 +289,12 @@ void CoreWorkerDirectTaskReceiver::RunNormalTasksFromQueue() { normal_scheduling_queue_->ScheduleRequests(); } +bool CoreWorkerDirectTaskReceiver::CancelQueuedActorTask(const WorkerID &caller_worker_id, + const TaskID &task_id) { + auto it = actor_scheduling_queues_.find(caller_worker_id); + return it->second->CancelTaskIfFound(task_id); +} + bool CoreWorkerDirectTaskReceiver::CancelQueuedNormalTask(TaskID task_id) { // Look up the task to be canceled in the queue of normal tasks. If it is found and // removed successfully, return true. diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 1b93f77d2893..6c377ea21e40 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -95,6 +95,12 @@ class CoreWorkerDirectTaskReceiver { bool CancelQueuedNormalTask(TaskID task_id); + /// Cancel an actor task queued in the actor scheduling queue for caller_worker_id. + /// Return true if a task is queued or executing. False otherwise. + /// If task is not executed yet, this will guarantee the task won't be executed. + /// This API is idempotent. + bool CancelQueuedActorTask(const WorkerID &caller_worker_id, const TaskID &task_id); + void Stop(); /// Set the actor repr name for an actor. 
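The caller-side protocol documented in the new CancelTask docstring reduces to a retry loop. Here is a schematic Python sketch of it; send_cancel_rpc, task_is_finished, and the task methods are hypothetical stand-ins for the C++ submitter internals, not real APIs.

import time


def cancel_actor_task(task, send_cancel_rpc, task_is_finished, retry_interval_s=2.0):
    if not task.mark_canceled():
        # Already finished or already canceled: nothing to do.
        return
    if task.queued_on_caller():
        # Not sent yet (or still resolving dependencies): cancel dependency
        # resolution and fail the object locally with TASK_CANCELLED.
        task.fail("TASK_CANCELLED")
        return
    # The task was already pushed to the executor. gRPC gives no ordering
    # guarantee between the push and the cancel, so keep retrying until the
    # executor acknowledges the cancel or the task is known to be finished.
    while not task_is_finished(task):
        reply = send_cancel_rpc(task)
        if reply.attempt_succeeded:
            return
        time.sleep(retry_interval_s)
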
diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index ec7a7f188d11..31d54e43e3d2 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -796,6 +796,7 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec, request.set_intended_task_id(task_spec.TaskId().Binary()); request.set_force_kill(force_kill); request.set_recursive(recursive); + request.set_caller_worker_id(task_spec.CallerWorkerId().Binary()); client->CancelTask( request, [this, task_spec, scheduling_key, force_kill, recursive](
diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 48c8dda7c75f..a095c737608a 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -115,6 +115,10 @@ class CoreWorkerDirectTaskSubmitter { /// \param[in] force_kill Whether to kill the worker executing the task. Status CancelTask(TaskSpecification task_spec, bool force_kill, bool recursive); + /// Request the owner of the object ID to cancel the corresponding task. + /// It is used when an object ID is not owned by the current process. + /// We cannot cancel the task directly in this case because we don't have + /// enough information to cancel it. Status CancelRemoteTask(const ObjectID &object_id, const rpc::Address &worker_addr, bool force_kill,
diff --git a/src/ray/core_worker/transport/normal_scheduling_queue.cc b/src/ray/core_worker/transport/normal_scheduling_queue.cc index 20e11097cab8..93505b942325 100644 --- a/src/ray/core_worker/transport/normal_scheduling_queue.cc +++ b/src/ray/core_worker/transport/normal_scheduling_queue.cc @@ -39,7 +39,7 @@ void NormalSchedulingQueue::Add( int64_t seq_no, int64_t client_processed_up_to, std::function<void(rpc::SendReplyCallback)> accept_request, - std::function<void(rpc::SendReplyCallback)> reject_request, + std::function<void(const Status &, rpc::SendReplyCallback)> reject_request, rpc::SendReplyCallback send_reply_callback, const std::string &concurrency_group_name, const FunctionDescriptor &function_descriptor, @@ -68,7 +68,7 @@ bool NormalSchedulingQueue::CancelTaskIfFound(TaskID task_id) { it != pending_normal_tasks_.rend(); ++it) { if (it->TaskID() == task_id) { - it->Cancel(); + it->Cancel(Status::OK()); pending_normal_tasks_.erase(std::next(it).base()); return true; }
diff --git a/src/ray/core_worker/transport/normal_scheduling_queue.h b/src/ray/core_worker/transport/normal_scheduling_queue.h index d2512a0cd9dd..a2c451fcc0ca 100644 --- a/src/ray/core_worker/transport/normal_scheduling_queue.h +++ b/src/ray/core_worker/transport/normal_scheduling_queue.h @@ -41,7 +41,7 @@ class NormalSchedulingQueue : public SchedulingQueue { void Add(int64_t seq_no, int64_t client_processed_up_to, std::function<void(rpc::SendReplyCallback)> accept_request, - std::function<void(rpc::SendReplyCallback)> reject_request, + std::function<void(const Status &, rpc::SendReplyCallback)> reject_request, rpc::SendReplyCallback send_reply_callback, const std::string &concurrency_group_name, const FunctionDescriptor &function_descriptor,
diff --git a/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc b/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc index 4bc24fdccc11..acb6e96a2fae 100644 --- a/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc +++ b/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc @@ -64,7 +64,7 @@ void OutOfOrderActorSchedulingQueue::Add( int64_t seq_no, int64_t client_processed_up_to, std::function<void(rpc::SendReplyCallback)> accept_request, - std::function<void(rpc::SendReplyCallback)>
reject_request, + std::function<void(const Status &, rpc::SendReplyCallback)> reject_request, rpc::SendReplyCallback send_reply_callback, const std::string &concurrency_group_name, const ray::FunctionDescriptor &function_descriptor, @@ -78,6 +78,10 @@ void OutOfOrderActorSchedulingQueue::Add( dependencies.size() > 0, concurrency_group_name, function_descriptor); + { + absl::MutexLock lock(&mu_); + pending_task_id_to_is_canceled.emplace(task_id, false); + } if (dependencies.size() > 0) { waiter_.Wait(dependencies, [this, request = std::move(request)]() mutable { @@ -94,33 +98,68 @@ } bool OutOfOrderActorSchedulingQueue::CancelTaskIfFound(TaskID task_id) { - RAY_CHECK(false) << "Cannot cancel actor tasks"; - return false; + absl::MutexLock lock(&mu_); + if (pending_task_id_to_is_canceled.find(task_id) != + pending_task_id_to_is_canceled.end()) { + // Mark the task as canceled. + pending_task_id_to_is_canceled[task_id] = true; + return true; + } else { + return false; + } } /// Schedules as many requests as possible in sequence. void OutOfOrderActorSchedulingQueue::ScheduleRequests() { while (!pending_actor_tasks_.empty()) { auto request = pending_actor_tasks_.front(); + const auto task_id = request.TaskID(); if (is_asyncio_) { // Process async actor task. auto fiber = fiber_state_manager_->GetExecutor(request.ConcurrencyGroupName(), request.FunctionDescriptor()); - fiber->EnqueueFiber([request]() mutable { request.Accept(); }); + fiber->EnqueueFiber([this, request, task_id]() mutable { + AcceptRequestOrRejectIfCanceled(task_id, request); + }); } else { // Process actor tasks. RAY_CHECK(pool_manager_ != nullptr); auto pool = pool_manager_->GetExecutor(request.ConcurrencyGroupName(), request.FunctionDescriptor()); if (pool == nullptr) { - request.Accept(); + AcceptRequestOrRejectIfCanceled(task_id, request); } else { - pool->Post([request]() mutable { request.Accept(); }); + pool->Post([this, request, task_id]() mutable { + AcceptRequestOrRejectIfCanceled(task_id, request); + }); } } pending_actor_tasks_.pop_front(); } } +void OutOfOrderActorSchedulingQueue::AcceptRequestOrRejectIfCanceled( + TaskID task_id, InboundRequest &request) { + bool is_canceled = false; + { + absl::MutexLock lock(&mu_); + auto it = pending_task_id_to_is_canceled.find(task_id); + if (it != pending_task_id_to_is_canceled.end()) { + is_canceled = it->second; + } + } + + // Accept can take a long time to run, so we should not hold the lock while calling it.
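+ // The flag was read while holding mu_, and the lock is released before
+ // invoking the callbacks below so that a long-running Accept() cannot block
+ // CancelTaskIfFound(). A cancel that arrives after this point is handled by
+ // the "Tasks are executing" branch of the protocol: future.cancel for async
+ // actor tasks, and a no-op otherwise.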
+ if (is_canceled) { + request.Cancel( + Status::SchedulingCancelled("Task is canceled before it is scheduled.")); + } else { + request.Accept(); + } + + absl::MutexLock lock(&mu_); + pending_task_id_to_is_canceled.erase(task_id); +} + } // namespace core } // namespace ray
diff --git a/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.h b/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.h index b5e8dcd4eca3..667a3b93093a 100644 --- a/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.h +++ b/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.h @@ -55,21 +55,27 @@ class OutOfOrderActorSchedulingQueue : public SchedulingQueue { void Add(int64_t seq_no, int64_t client_processed_up_to, std::function<void(rpc::SendReplyCallback)> accept_request, - std::function<void(rpc::SendReplyCallback)> reject_request, + std::function<void(const Status &, rpc::SendReplyCallback)> reject_request, rpc::SendReplyCallback send_reply_callback, const std::string &concurrency_group_name, const ray::FunctionDescriptor &function_descriptor, TaskID task_id = TaskID::Nil(), const std::vector<rpc::ObjectReference> &dependencies = {}) override; - // We don't allow the cancellation of actor tasks, so invoking CancelTaskIfFound - // results in a fatal error. + /// Cancel the actor task in the queue. + /// A task is in the queue if it is either queued or executing. + /// Returns true if the task is in the queue; false otherwise. + /// This method has to be THREAD-SAFE. bool CancelTaskIfFound(TaskID task_id) override; /// Schedules as many requests as possible in sequence. void ScheduleRequests() override; private: + /// Accept the given InboundRequest, or reject it if its task id has been + /// canceled via CancelTaskIfFound. + void AcceptRequestOrRejectIfCanceled(TaskID task_id, InboundRequest &request); + /// The queue stores all the pending tasks. std::deque<InboundRequest> pending_actor_tasks_; /// The id of the thread that constructed this scheduling queue. @@ -84,6 +90,11 @@ class OutOfOrderActorSchedulingQueue : public SchedulingQueue { /// Whether we should enqueue requests into asyncio pool. Setting this to true /// will instantiate all tasks as fibers that can be yielded. bool is_asyncio_ = false; + /// Mutex to protect attributes used for thread-safe APIs. + absl::Mutex mu_; + /// A map of actor task IDs -> is_canceled. + // Pending means the task is queued or running. + absl::flat_hash_map<TaskID, bool> pending_task_id_to_is_canceled GUARDED_BY(mu_); friend class SchedulingQueueTest; };
diff --git a/src/ray/core_worker/transport/out_of_order_actor_submit_queue.cc b/src/ray/core_worker/transport/out_of_order_actor_submit_queue.cc index d502605e02bb..71bfaf0126d3 100644 --- a/src/ray/core_worker/transport/out_of_order_actor_submit_queue.cc +++ b/src/ray/core_worker/transport/out_of_order_actor_submit_queue.cc @@ -49,6 +49,11 @@ void OutofOrderActorSubmitQueue::MarkDependencyFailed(uint64_t position) { pending_queue_.erase(position); } +void OutofOrderActorSubmitQueue::MarkTaskCanceled(uint64_t position) { + pending_queue_.erase(position); + sending_queue_.erase(position); +} + void OutofOrderActorSubmitQueue::MarkDependencyResolved(uint64_t position) { // move the task from pending_requests queue to sending_requests queue.
auto it = pending_queue_.find(position);
diff --git a/src/ray/core_worker/transport/out_of_order_actor_submit_queue.h b/src/ray/core_worker/transport/out_of_order_actor_submit_queue.h index 0eaad785437e..553074c52c7b 100644 --- a/src/ray/core_worker/transport/out_of_order_actor_submit_queue.h +++ b/src/ray/core_worker/transport/out_of_order_actor_submit_queue.h @@ -46,6 +46,10 @@ class OutofOrderActorSubmitQueue : public IActorSubmitQueue { void MarkDependencyFailed(uint64_t position) override; /// Mark a task's dependency as resolved, and thus the task as ready to send. void MarkDependencyResolved(uint64_t position) override; + // Mark a task as canceled. + // If the task hasn't been sent yet, this API guarantees it won't be + // popped via PopNextTaskToSend. + void MarkTaskCanceled(uint64_t position) override; /// Clear the queue and return all task ids that haven't been sent yet. std::vector<TaskID> ClearAllTasks() override; /// Find next task to send.
diff --git a/src/ray/core_worker/transport/scheduling_queue.h b/src/ray/core_worker/transport/scheduling_queue.h index 5ea71e926177..c5133831a36b 100644 --- a/src/ray/core_worker/transport/scheduling_queue.h +++ b/src/ray/core_worker/transport/scheduling_queue.h @@ -27,15 +27,16 @@ namespace core { class SchedulingQueue { public: virtual ~SchedulingQueue() = default; - virtual void Add(int64_t seq_no, - int64_t client_processed_up_to, - std::function<void(rpc::SendReplyCallback)> accept_request, - std::function<void(rpc::SendReplyCallback)> reject_request, - rpc::SendReplyCallback send_reply_callback, - const std::string &concurrency_group_name, - const ray::FunctionDescriptor &function_descriptor, - TaskID task_id = TaskID::Nil(), - const std::vector<rpc::ObjectReference> &dependencies = {}) = 0; + virtual void Add( + int64_t seq_no, + int64_t client_processed_up_to, + std::function<void(rpc::SendReplyCallback)> accept_request, + std::function<void(const Status &, rpc::SendReplyCallback)> reject_request, + rpc::SendReplyCallback send_reply_callback, + const std::string &concurrency_group_name, + const ray::FunctionDescriptor &function_descriptor, + TaskID task_id = TaskID::Nil(), + const std::vector<rpc::ObjectReference> &dependencies = {}) = 0; virtual void ScheduleRequests() = 0; virtual bool TaskQueueEmpty() const = 0; virtual size_t Size() const = 0;
diff --git a/src/ray/core_worker/transport/sequential_actor_submit_queue.cc b/src/ray/core_worker/transport/sequential_actor_submit_queue.cc index 7a499893b704..fa6a8ce5baa2 100644 --- a/src/ray/core_worker/transport/sequential_actor_submit_queue.cc +++ b/src/ray/core_worker/transport/sequential_actor_submit_queue.cc @@ -41,6 +41,12 @@ void SequentialActorSubmitQueue::MarkDependencyFailed(uint64_t sequence_no) { requests.erase(sequence_no); } +void SequentialActorSubmitQueue::MarkTaskCanceled(uint64_t sequence_no) { + requests.erase(sequence_no); + // No need to clean up out_of_order_completed_tasks, because an entry there + // means the task has already been submitted and finished. +} + void SequentialActorSubmitQueue::MarkDependencyResolved(uint64_t sequence_no) { auto it = requests.find(sequence_no); RAY_CHECK(it != requests.end());
diff --git a/src/ray/core_worker/transport/sequential_actor_submit_queue.h b/src/ray/core_worker/transport/sequential_actor_submit_queue.h index b0713fff813c..2a3919c6cd34 100644 --- a/src/ray/core_worker/transport/sequential_actor_submit_queue.h +++ b/src/ray/core_worker/transport/sequential_actor_submit_queue.h @@ -42,6 +42,10 @@ class SequentialActorSubmitQueue : public IActorSubmitQueue { void MarkDependencyFailed(uint64_t sequence_no) override; /// Mark a task's dependency as resolved, and thus the task as ready to send.
void MarkDependencyResolved(uint64_t sequence_no) override; + // Mark a task as canceled. + // If the task hasn't been sent yet, this API guarantees it won't be + // popped via PopNextTaskToSend. + void MarkTaskCanceled(uint64_t sequence_no) override; /// Clear the queue and return all task ids that haven't been sent yet. std::vector<TaskID> ClearAllTasks() override; /// Find next task to send.
diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 76aedea7dc7b..6ec1724adf72 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -250,6 +250,8 @@ message CancelTaskRequest { bool force_kill = 2; // Whether to recursively cancel tasks. bool recursive = 3; + // The worker ID of the caller. + bytes caller_worker_id = 4; } message CancelTaskReply {
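The executor-side bookkeeping added to OutOfOrderActorSchedulingQueue boils down to a small pattern: record each inbound task as pending and not canceled, let a thread-safe cancel call flip the flag, check the flag right before running the task, and erase the entry afterwards. The following standalone Python sketch models that pattern only; it is an illustration of the logic, not Ray code, and the PendingTaskRegistry name and its methods are invented for this example.

import threading

class PendingTaskRegistry:
    """Simplified model of pending_task_id_to_is_canceled in the C++ queue."""

    def __init__(self):
        self._lock = threading.Lock()
        self._is_canceled = {}  # task_id -> bool

    def add(self, task_id):
        # Called when the task request is queued (mirrors Add()).
        with self._lock:
            self._is_canceled[task_id] = False

    def cancel_if_found(self, task_id):
        # Thread-safe; mirrors CancelTaskIfFound. Returns True if the task is
        # still queued or executing, False if it has finished or never arrived.
        with self._lock:
            if task_id in self._is_canceled:
                self._is_canceled[task_id] = True
                return True
            return False

    def accept_or_reject(self, task_id, accept, reject):
        # Mirrors AcceptRequestOrRejectIfCanceled: read the flag under the lock,
        # but run the (potentially long) accept/reject callback outside of it.
        with self._lock:
            is_canceled = self._is_canceled.get(task_id, False)
        if is_canceled:
            reject("Task is canceled before it is scheduled.")
        else:
            accept()
        with self._lock:
            self._is_canceled.pop(task_id, None)

Erasing the entry once the task has run is what makes a later cancel lookup return false; per the protocol above, the executor then fails the cancel RPC and relies on the caller to stop retrying once it observes that the task has finished.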