diff --git a/python/ray/autoscaler/v2/BUILD b/python/ray/autoscaler/v2/BUILD index d7859953e077b..f0a2a061b073d 100644 --- a/python/ray/autoscaler/v2/BUILD +++ b/python/ray/autoscaler/v2/BUILD @@ -62,6 +62,7 @@ py_test( deps = ["//:ray_lib",], ) + py_test( name = "test_reconciler", size = "small", diff --git a/python/ray/autoscaler/v2/instance_manager/reconciler.py b/python/ray/autoscaler/v2/instance_manager/reconciler.py index e32f875350119..0ccb15c9773f4 100644 --- a/python/ray/autoscaler/v2/instance_manager/reconciler.py +++ b/python/ray/autoscaler/v2/instance_manager/reconciler.py @@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Set, Tuple from ray._private.protobuf_compat import message_to_dict +from ray._private.utils import binary_to_hex from ray.autoscaler.v2.instance_manager.common import InstanceUtil from ray.autoscaler.v2.instance_manager.config import ( AutoscalingConfig, @@ -22,6 +23,7 @@ TerminateNodeError, ) from ray.autoscaler.v2.instance_manager.ray_installer import RayInstallError +from ray.autoscaler.v2.instance_manager.subscribers.ray_stopper import RayStopError from ray.autoscaler.v2.scheduler import IResourceScheduler, SchedulingRequest from ray.autoscaler.v2.schema import AutoscalerInstance, NodeType from ray.core.generated.autoscaler_pb2 import ( @@ -162,9 +164,10 @@ def reconcile( cloud_provider: ICloudInstanceProvider, ray_cluster_resource_state: ClusterResourceState, non_terminated_cloud_instances: Dict[CloudInstanceId, CloudInstance], - cloud_provider_errors: List[CloudInstanceProviderError], - ray_install_errors: List[RayInstallError], autoscaling_config: AutoscalingConfig, + cloud_provider_errors: Optional[List[CloudInstanceProviderError]] = None, + ray_install_errors: Optional[List[RayInstallError]] = None, + ray_stop_errors: Optional[List[RayStopError]] = None, _logger: Optional[logging.Logger] = None, ) -> AutoscalingState: """ @@ -189,8 +192,13 @@ def reconcile( the cloud provider. cloud_provider_errors: The errors from the cloud provider. ray_install_errors: The errors from RayInstaller. + ray_stop_errors: The errors from RayStopper. """ + cloud_provider_errors = cloud_provider_errors or [] + ray_install_errors = ray_install_errors or [] + ray_stop_errors = ray_stop_errors or [] + autoscaling_state = AutoscalingState() Reconciler._sync_from( instance_manager=instance_manager, @@ -198,6 +206,7 @@ def reconcile( non_terminated_cloud_instances=non_terminated_cloud_instances, cloud_provider_errors=cloud_provider_errors, ray_install_errors=ray_install_errors, + ray_stop_errors=ray_stop_errors, ) Reconciler._step_next( @@ -219,6 +228,7 @@ def _sync_from( non_terminated_cloud_instances: Dict[CloudInstanceId, CloudInstance], cloud_provider_errors: List[CloudInstanceProviderError], ray_install_errors: List[RayInstallError], + ray_stop_errors: List[RayStopError], ): """ Reconcile the instance states of the instance manager from external states like @@ -254,6 +264,9 @@ def _sync_from( instance to RAY_STOPPED. 7. * -> RAY_INSTALL_FAILED: When there's an error from RayInstaller. + 8. RAY_STOP_REQUESTED -> RAY_RUNNING: + When requested to stop ray, but failed to stop/drain the ray node + (e.g. idle termination drain rejected by the node). Args: instance_manager: The instance manager to reconcile. 
@@ -282,6 +295,8 @@ def _sync_from( Reconciler._handle_ray_install_failed(instance_manager, ray_install_errors) + Reconciler._handle_ray_stop_failed(instance_manager, ray_stop_errors, ray_nodes) + @staticmethod def _step_next( autoscaling_state: AutoscalingState, @@ -468,6 +483,69 @@ def _try_resolve_pending_allocation( # No update. return None + @staticmethod + def _handle_ray_stop_failed( + instance_manager: InstanceManager, + ray_stop_errors: List[RayStopError], + ray_nodes: List[NodeState], + ): + """ + Handles instances that requested to stop ray but failed to stop/drain the ray node. + E.g. connection errors, or an idle termination drain rejected by the node. + + We will transition the instance back to RAY_RUNNING. + + Args: + instance_manager: The instance manager to reconcile. + ray_stop_errors: The errors from RayStopper. + + """ + instances, version = Reconciler._get_im_instances(instance_manager) + updates = {} + + ray_stop_errors_by_instance_id = { + error.im_instance_id: error for error in ray_stop_errors + } + + ray_nodes_by_ray_node_id = {binary_to_hex(n.node_id): n for n in ray_nodes} + + ray_stop_requested_instances = { + instance.instance_id: instance + for instance in instances + if instance.status == IMInstance.RAY_STOP_REQUESTED + } + + for instance_id, instance in ray_stop_requested_instances.items(): + stop_error = ray_stop_errors_by_instance_id.get(instance_id) + if not stop_error: + continue + + assert instance.node_id + ray_node = ray_nodes_by_ray_node_id.get(instance.node_id) + assert ray_node is not None and ray_node.status in [ + NodeStatus.RUNNING, + NodeStatus.IDLE, + ], ( + "There should be a running ray node for an instance whose " + "ray stop request failed." + ) + + updates[instance_id] = IMInstanceUpdateEvent( + instance_id=instance_id, + new_instance_status=IMInstance.RAY_RUNNING, + details="Failed to stop/drain ray.", + ray_node_id=instance.node_id, + ) + logger.debug( + "Updating {}({}) with {}".format( + instance_id, + IMInstance.InstanceStatus.Name(instance.status), + message_to_dict(updates[instance_id]), + ) + ) + + Reconciler._update_instance_manager(instance_manager, version, updates) + @staticmethod def _handle_ray_install_failed( instance_manager: InstanceManager, ray_install_errors: List[RayInstallError] @@ -676,7 +754,7 @@ def _handle_ray_status_transition( else: # This should only happen to a ray node that's not managed by us. logger.warning( - f"Ray node {n.node_id.decode()} has no instance id. " + f"Ray node {binary_to_hex(n.node_id)} has no instance id. " "This only happens to a ray node that's not managed by autoscaler. " "If not, please file a bug at https://github.com/ray-project/ray" ) @@ -687,8 +765,8 @@ def _handle_ray_status_transition( # or we haven't discovered the instance yet. There's nothing # much we could do here. logger.info( - f"Ray node {ray_node.node_id.decode()} has no matching instance in " - f"instance manager with cloud instance id={cloud_instance_id}." + f"Ray node {binary_to_hex(ray_node.node_id)} has no matching " + f"instance with cloud instance id={cloud_instance_id}."
) continue @@ -703,8 +781,8 @@ def _handle_ray_status_transition( new_instance_status=reconciled_im_status, details="Reconciled from ray node status " f"{NodeStatus.Name(ray_node.status)} " - f"for ray node {ray_node.node_id.decode()}", - ray_node_id=ray_node.node_id.decode(), + f"for ray node {binary_to_hex(ray_node.node_id)}", + ray_node_id=binary_to_hex(ray_node.node_id), ) logger.debug( "Updating {}({}) with {}.".format( @@ -1065,7 +1143,7 @@ def _scale_cluster( autoscaler_instances = [] ray_nodes_by_id = { - node.node_id.decode(): node for node in ray_state.node_states + binary_to_hex(node.node_id): node for node in ray_state.node_states } for im_instance in im_instances: diff --git a/python/ray/autoscaler/v2/instance_manager/subscribers/ray_stopper.py b/python/ray/autoscaler/v2/instance_manager/subscribers/ray_stopper.py new file mode 100644 index 0000000000000..1723dd23555f7 --- /dev/null +++ b/python/ray/autoscaler/v2/instance_manager/subscribers/ray_stopper.py @@ -0,0 +1,140 @@ +import logging +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from queue import Queue +from typing import List + +from ray._private.utils import hex_to_binary +from ray._raylet import GcsClient +from ray.autoscaler.v2.instance_manager.instance_manager import ( + InstanceUpdatedSubscriber, +) +from ray.core.generated.autoscaler_pb2 import DrainNodeReason +from ray.core.generated.instance_manager_pb2 import ( + Instance, + InstanceUpdateEvent, + TerminationRequest, +) + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class RayStopError: + # Instance manager's instance id. + im_instance_id: str + + +class RayStopper(InstanceUpdatedSubscriber): + """RayStopper is responsible for stopping ray on instances. + + It will drain the ray node if it's for idle termination. + For other terminations (e.g. scale down), it will stop the ray node. + + If any failures happen when stopping/draining the node, we will not retry + and rely on the reconciler to handle the failure. + + TODO: we could also surface the errors back to the reconciler for + quicker failure detection. + + """ + + def __init__(self, gcs_client: GcsClient, error_queue: Queue) -> None: + self._gcs_client = gcs_client + self._error_queue = error_queue + self._executor = ThreadPoolExecutor(max_workers=1) + + def notify(self, events: List[InstanceUpdateEvent]) -> None: + for event in events: + if event.new_instance_status == Instance.RAY_STOP_REQUESTED: + fut = self._executor.submit(self._stop_or_drain_ray, event) + + def _log_on_error(fut): + try: + fut.result() + except Exception: + logger.exception("Error stopping/draining ray.") + + fut.add_done_callback(_log_on_error) + + def _stop_or_drain_ray(self, event: InstanceUpdateEvent) -> None: + """ + Stops or drains the ray node based on the termination request. + """ + assert event.HasField("termination_request"), "Termination request is required." + termination_request = event.termination_request + ray_node_id = termination_request.ray_node_id + instance_id = event.instance_id + + if termination_request.cause == TerminationRequest.Cause.IDLE: + reason = DrainNodeReason.DRAIN_NODE_REASON_IDLE_TERMINATION + reason_str = "Termination of node that's idle for {} seconds.".format( + termination_request.idle_time_ms / 1000 + ) + self._drain_ray_node( + self._gcs_client, + self._error_queue, + ray_node_id, + instance_id, + reason, + reason_str, + ) + return + + # If it's not an idle termination, we stop the ray node.
+ self._stop_ray_node( + self._gcs_client, self._error_queue, ray_node_id, instance_id + ) + + @staticmethod + def _drain_ray_node( + gcs_client: GcsClient, + error_queue: Queue, + ray_node_id: str, + instance_id: str, + reason: DrainNodeReason, + reason_str: str, + ): + """ + Drains the ray node. + + Args: + gcs_client: The gcs client to use. + ray_node_id: The ray node id to drain. + reason: The reason to drain the node. + reason_str: The reason message to drain the node. + """ + accepted = gcs_client.drain_node( + node_id=ray_node_id, + reason=reason, + reason_message=reason_str, + # TODO: we could probably add a deadline here that's derived + # from the stuck instance reconciliation configs. + deadline_timestamp_ms=0, + ) + logger.info(f"Draining ray on {ray_node_id}(success={accepted}): {reason_str}") + if not accepted: + error_queue.put_nowait(RayStopError(im_instance_id=instance_id)) + + @staticmethod + def _stop_ray_node( + gcs_client: GcsClient, + error_queue: Queue, + ray_node_id: str, + instance_id: str, + ): + """ + Stops the ray node. + + Args: + gcs_client: The gcs client to use. + ray_node_id: The ray node id to stop. + """ + drained = gcs_client.drain_nodes(node_ids=[hex_to_binary(ray_node_id)]) + success = len(drained) > 0 + logger.info( + f"Stopping ray on {ray_node_id}(instance={instance_id}): success={success}" + ) + + if not success: + error_queue.put_nowait(RayStopError(im_instance_id=instance_id)) diff --git a/python/ray/autoscaler/v2/tests/test_reconciler.py b/python/ray/autoscaler/v2/tests/test_reconciler.py index b8cd581ac41d3..5d027b23cb132 100644 --- a/python/ray/autoscaler/v2/tests/test_reconciler.py +++ b/python/ray/autoscaler/v2/tests/test_reconciler.py @@ -8,6 +8,7 @@ import mock from mock import MagicMock +from ray._private.utils import binary_to_hex from ray.autoscaler.v2.instance_manager.config import InstanceReconcileConfig from ray.autoscaler.v2.instance_manager.instance_manager import InstanceManager from ray.autoscaler.v2.instance_manager.instance_storage import InstanceStorage @@ -19,6 +20,7 @@ from ray.autoscaler.v2.instance_manager.ray_installer import RayInstallError from ray.autoscaler.v2.instance_manager.reconciler import Reconciler, logger from ray.autoscaler.v2.instance_manager.storage import InMemoryStorage +from ray.autoscaler.v2.instance_manager.subscribers.ray_stopper import RayStopError from ray.autoscaler.v2.scheduler import IResourceScheduler, SchedulingReply from ray.autoscaler.v2.tests.util import MockSubscriber, create_instance from ray.core.generated.autoscaler_pb2 import ( @@ -388,7 +390,7 @@ def test_ray_reconciler_new_ray(setup): instances, _ = instance_storage.get_instances() assert len(instances) == 1 assert instances["i-1"].status == Instance.RAY_RUNNING - assert instances["i-1"].node_id == "r-1" + assert instances["i-1"].node_id == binary_to_hex(b"r-1") @staticmethod def test_ray_reconciler_already_ray_running(setup): @@ -799,6 +801,65 @@ def test_stuck_instances_ray_stop_requested(mock_time_ns, setup): assert instances["no-update"].status == cur_status assert instances["updated"].status == Instance.RAY_RUNNING + @staticmethod + @mock.patch("time.time_ns") + def test_ray_stop_requested_fail(mock_time_ns, setup): + # Test that the instance is transitioned to RAY_RUNNING + # when the ray stop request fails.
+ + instance_manager, instance_storage, _ = setup + mock_time_ns.return_value = 10 * s_to_ns + + instances = [ + create_instance( + "i-1", + status=Instance.RAY_STOP_REQUESTED, + ray_node_id=binary_to_hex(b"r-1"), + cloud_instance_id="c-1", + status_times=[(Instance.RAY_STOP_REQUESTED, 10 * s_to_ns)], + ), + create_instance( + "i-2", + status=Instance.RAY_STOP_REQUESTED, + ray_node_id=binary_to_hex(b"r-2"), + cloud_instance_id="c-2", + status_times=[(Instance.RAY_STOP_REQUESTED, 10 * s_to_ns)], + ), + ] + + ray_nodes = [ + NodeState(node_id=b"r-1", status=NodeStatus.RUNNING, instance_id="c-1"), + NodeState(node_id=b"r-2", status=NodeStatus.RUNNING, instance_id="c-2"), + ] + + ray_stop_errors = [ + RayStopError(im_instance_id="i-1"), + ] + + cloud_instances = { + "c-1": CloudInstance("c-1", "type-1", "", True, NodeKind.WORKER), + "c-2": CloudInstance("c-2", "type-2", "", True, NodeKind.WORKER), + } + + TestReconciler._add_instances(instance_storage, instances) + + Reconciler.reconcile( + instance_manager=instance_manager, + scheduler=MockScheduler(), + cloud_provider=MagicMock(), + ray_cluster_resource_state=ClusterResourceState(node_states=ray_nodes), + non_terminated_cloud_instances=cloud_instances, + cloud_provider_errors=[], + ray_install_errors=[], + ray_stop_errors=ray_stop_errors, + autoscaling_config=MockAutoscalingConfig(), + ) + + instances, _ = instance_storage.get_instances() + assert len(instances) == 2 + assert instances["i-1"].status == Instance.RAY_RUNNING + assert instances["i-2"].status == Instance.RAY_STOP_REQUESTED + @staticmethod @mock.patch("time.time_ns") @pytest.mark.parametrize( diff --git a/python/ray/autoscaler/v2/tests/test_subscribers.py b/python/ray/autoscaler/v2/tests/test_subscribers.py index fd6bea4990408..aca119aa4c740 100644 --- a/python/ray/autoscaler/v2/tests/test_subscribers.py +++ b/python/ray/autoscaler/v2/tests/test_subscribers.py @@ -1,16 +1,151 @@ # coding: utf-8 import os import sys +from queue import Queue import pytest import mock from ray._private.test_utils import wait_for_condition +from ray._private.utils import binary_to_hex, hex_to_binary from ray.autoscaler.v2.instance_manager.subscribers.cloud_instance_updater import ( CloudInstanceUpdater, ) -from ray.core.generated.instance_manager_pb2 import Instance, InstanceUpdateEvent +from ray.autoscaler.v2.instance_manager.subscribers.ray_stopper import ( # noqa + RayStopper, +) +from ray.core.generated.autoscaler_pb2 import DrainNodeReason +from ray.core.generated.instance_manager_pb2 import ( + Instance, + InstanceUpdateEvent, + TerminationRequest, +) + + +class TestRayStopper: + def test_no_op(self): + mock_gcs_client = mock.MagicMock() + ray_stopper = RayStopper(gcs_client=mock_gcs_client, error_queue=Queue()) + + ray_stopper.notify( + [ + InstanceUpdateEvent( + instance_id="test_id", + new_instance_status=Instance.REQUESTED, + ) + ] + ) + assert mock_gcs_client.drain_node.call_count == 0 + + @pytest.mark.parametrize( + "drain_accepted", + [True, False], + ids=["drain_accepted", "drain_rejected"], + ) + def test_idle_termination(self, drain_accepted): + mock_gcs_client = mock.MagicMock() + mock_gcs_client.drain_node.return_value = drain_accepted + error_queue = Queue() + ray_stopper = RayStopper(gcs_client=mock_gcs_client, error_queue=error_queue) + + ray_stopper.notify( + [ + InstanceUpdateEvent( + instance_id="test_id", + new_instance_status=Instance.RAY_STOP_REQUESTED, + termination_request=TerminationRequest( + cause=TerminationRequest.Cause.IDLE, + idle_time_ms=1000, + 
ray_node_id="0000", + ), + ) + ] + ) + + def verify(): + mock_gcs_client.drain_node.assert_has_calls( + [ + mock.call( + node_id="0000", + reason=DrainNodeReason.DRAIN_NODE_REASON_IDLE_TERMINATION, + reason_message=( + "Termination of node that's idle for 1.0 seconds." + ), + deadline_timestamp_ms=0, + ) + ] + ) + + if drain_accepted: + assert error_queue.empty() + else: + error = error_queue.get_nowait() + assert error.im_instance_id == "test_id" + + return True + + wait_for_condition(verify) + + @pytest.mark.parametrize( + "stop_accepted", + [True, False], + ids=["stop_accepted", "stop_rejected"], + ) + def test_preemption(self, stop_accepted): + mock_gcs_client = mock.MagicMock() + mock_gcs_client.drain_nodes.return_value = [0] if stop_accepted else [] + error_queue = Queue() + ray_stopper = RayStopper(gcs_client=mock_gcs_client, error_queue=error_queue) + + ray_stopper.notify( + [ + InstanceUpdateEvent( + instance_id="i-1", + new_instance_status=Instance.RAY_STOP_REQUESTED, + termination_request=TerminationRequest( + cause=TerminationRequest.Cause.MAX_NUM_NODE_PER_TYPE, + max_num_nodes_per_type=10, + ray_node_id=binary_to_hex(hex_to_binary(b"1111")), + ), + ), + InstanceUpdateEvent( + instance_id="i-2", + new_instance_status=Instance.RAY_STOP_REQUESTED, + termination_request=TerminationRequest( + cause=TerminationRequest.Cause.MAX_NUM_NODES, + max_num_nodes=100, + ray_node_id=binary_to_hex(hex_to_binary(b"2222")), + ), + ), + ] + ) + + def verify(): + mock_gcs_client.drain_nodes.assert_has_calls( + [ + mock.call( + node_ids=[hex_to_binary(b"1111")], + ), + mock.call( + node_ids=[hex_to_binary(b"2222")], + ), + ] + ) + + if stop_accepted: + assert error_queue.empty() + else: + error_in_ids = set() + while not error_queue.empty(): + error = error_queue.get_nowait() + error_in_ids.add(error.im_instance_id) + + assert error_in_ids == {"i-1", "i-2"} + + return True + + wait_for_condition(verify) class TestCloudInstanceUpdater: