From 40c4148d4f065802f7888017ef26c2213b45341b Mon Sep 17 00:00:00 2001
From: Richard Liaw
Date: Sat, 20 Oct 2018 22:56:29 -0700
Subject: [PATCH] Cluster Utilities for Fault Tolerance Tests (#3008)

---
 .travis.yml                      |   2 +
 python/ray/test/cluster_utils.py | 201 +++++++++++++++++++++++++++++++
 python/ray/worker.py             |   4 -
 test/multi_node_test_2.py        |  72 +++++++++++
 4 files changed, 275 insertions(+), 4 deletions(-)
 create mode 100644 python/ray/test/cluster_utils.py
 create mode 100644 test/multi_node_test_2.py

diff --git a/.travis.yml b/.travis.yml
index bd0fd929b7334..70f548f83bf89 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -146,6 +146,7 @@ matrix:
       - python -m pytest -v test/stress_tests.py
       - pytest test/component_failures_test.py
       - python test/multi_node_test.py
+      - python -m pytest -v test/multi_node_test_2.py
       - python -m pytest -v test/recursion_test.py
       - pytest test/monitor_test.py
       - python -m pytest -v test/cython_test.py
@@ -223,6 +224,7 @@ script:
   - python -m pytest -v test/stress_tests.py
   - python -m pytest -v test/component_failures_test.py
   - python test/multi_node_test.py
+  - python -m pytest -v test/multi_node_test_2.py
   - python -m pytest -v test/recursion_test.py
   - python -m pytest -v test/monitor_test.py
   - python -m pytest -v test/cython_test.py

diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py
new file mode 100644
index 0000000000000..7e7a82d67c6e7
--- /dev/null
+++ b/python/ray/test/cluster_utils.py
@@ -0,0 +1,201 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import time

import ray
import ray.services as services

logger = logging.getLogger(__name__)


class Cluster(object):
    def __init__(self,
                 initialize_head=False,
                 connect=False,
                 head_node_args=None):
        """Initializes the cluster.

        Args:
            initialize_head (bool): Automatically start a Ray cluster
                by initializing the head node. Defaults to False.
            connect (bool): If `initialize_head=True` and `connect=True`,
                ray.init will be called with the redis address of this
                cluster passed in.
            head_node_args (dict): Arguments to be passed into
                `start_ray_head` via `self.add_node`.
        """
        self.head_node = None
        self.worker_nodes = {}
        self.redis_address = None
        if not initialize_head and connect:
            raise RuntimeError("Cannot connect to an uninitialized cluster.")

        if initialize_head:
            head_node_args = head_node_args or {}
            self.add_node(**head_node_args)
            if connect:
                ray.init(redis_address=self.redis_address)

    def add_node(self, **override_kwargs):
        """Adds a node to the local Ray cluster.

        All nodes are started with the following settings by default:
            cleanup=True,
            use_raylet=True,
            resources={"CPU": 1},
            object_store_memory=100 * (2**20)  # 100 MB

        Args:
            override_kwargs: Keyword arguments passed to `start_ray_head`
                (for the first node) or `start_ray_node`. These override
                the defaults above.

        Returns:
            Node object of the added Ray node.
        """
        node_kwargs = dict(
            cleanup=True,
            use_raylet=True,
            resources={"CPU": 1},
            object_store_memory=100 * (2**20)  # 100 MB
        )
        node_kwargs.update(override_kwargs)

        if self.head_node is None:
            address_info = services.start_ray_head(
                node_ip_address=services.get_node_ip_address(),
                include_webui=False,
                **node_kwargs)
            self.redis_address = address_info["redis_address"]
            # TODO(rliaw): Find a more stable way than modifying global state.
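            # `start_ray_head` records every process it launches in the
            # global `services.all_processes` dict (process type name ->
            # list of process handles). Copy that registry and reset the
            # global lists so the new Node object owns exactly the
            # processes started for it and can kill them independently.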
            process_dict_copy = services.all_processes.copy()
            for key in services.all_processes:
                services.all_processes[key] = []
            node = Node(process_dict_copy)
            self.head_node = node
        else:
            address_info = services.start_ray_node(
                services.get_node_ip_address(), self.redis_address,
                **node_kwargs)
            # TODO(rliaw): Find a more stable way than modifying global state.
            process_dict_copy = services.all_processes.copy()
            for key in services.all_processes:
                services.all_processes[key] = []
            node = Node(process_dict_copy)
            self.worker_nodes[node] = address_info
            logger.info("Starting node with raylet socket {}".format(
                address_info["raylet_socket_names"]))

        return node

    def remove_node(self, node):
        """Kills all processes associated with the given node.

        Args:
            node (Node): Node whose associated processes will all be
                removed.
        """
        if self.head_node == node:
            self.head_node.kill_all_processes()
            self.head_node = None
            # TODO(rliaw): Do we need to kill all worker processes?
        else:
            node.kill_all_processes()
            self.worker_nodes.pop(node)

        assert not node.any_processes_alive(), (
            "There are zombie processes left over after killing.")

    def wait_for_nodes(self, retries=20):
        """Waits for all nodes to be registered with global state.

        Args:
            retries (int): Number of times to check the client table
                before giving up.
        """
        for _ in range(retries):
            if not ray.is_initialized() or not self._check_registered_nodes():
                time.sleep(0.3)
            else:
                break

    def _check_registered_nodes(self):
        registered = len([
            client for client in ray.global_state.client_table()
            if client["IsInsertion"]
        ])
        expected = len(self.list_all_nodes())
        if registered == expected:
            logger.info("All nodes registered as expected.")
        else:
            logger.info("{} nodes registered so far, but {} expected.".format(
                registered, expected))
        return registered == expected

    def list_all_nodes(self):
        """Lists all nodes, including the head node.

        TODO(rliaw): What is the desired behavior if a head node
        dies before worker nodes die?

        Returns:
            List of all nodes, including the head node.
        """
        nodes = list(self.worker_nodes)
        if self.head_node:
            nodes = [self.head_node] + nodes
        return nodes

    def shutdown(self):
        # Make a copy of the list here because `remove_node`
        # modifies `self.worker_nodes`.
        all_nodes = list(self.worker_nodes)
        for node in all_nodes:
            self.remove_node(node)
        self.remove_node(self.head_node)


class Node(object):
    """Abstraction for a Ray node."""

    def __init__(self, process_dict):
        # TODO(rliaw): Is there a unique identifier for a node?
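        # `process_dict` maps a process type name from `ray.services`
        # (e.g. PROCESS_TYPE_RAYLET) to the list of process handles
        # started for this node.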
        self.process_dict = process_dict

    def kill_plasma_store(self):
        self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].kill()
        self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].wait()

    def kill_raylet(self):
        self.process_dict[services.PROCESS_TYPE_RAYLET][0].kill()
        self.process_dict[services.PROCESS_TYPE_RAYLET][0].wait()

    def kill_log_monitor(self):
        self.process_dict["log_monitor"][0].kill()
        self.process_dict["log_monitor"][0].wait()

    def kill_all_processes(self):
        for process_name, process_list in self.process_dict.items():
            logger.info("Killing all {}(s).".format(process_name))
            for process in process_list:
                process.kill()

        for process_name, process_list in self.process_dict.items():
            logger.info("Waiting for all {}(s) to exit.".format(process_name))
            for process in process_list:
                process.wait()

    def live_processes(self):
        return [(p_name, proc) for p_name, p_list in self.process_dict.items()
                for proc in p_list if proc.poll() is None]

    def dead_processes(self):
        return [(p_name, proc) for p_name, p_list in self.process_dict.items()
                for proc in p_list if proc.poll() is not None]

    def any_processes_alive(self):
        return any(self.live_processes())

    def all_processes_alive(self):
        return not any(self.dead_processes())

diff --git a/python/ray/worker.py b/python/ray/worker.py
index 7049b1f3a4297..e19c433753b73 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -1837,10 +1837,6 @@ def shutdown(worker=global_worker):
         worker.plasma_client.disconnect()
 
     if worker.mode == SCRIPT_MODE:
-        # If this is a driver, push the finish time to Redis and clean up any
-        # other services that were started with the driver.
-        worker.redis_client.hmset(b"Drivers:" + worker.worker_id,
-                                  {"end_time": time.time()})
         services.cleanup()
     else:
         # If this is not a driver, make sure there are no orphan processes,

diff --git a/test/multi_node_test_2.py b/test/multi_node_test_2.py
new file mode 100644
index 0000000000000..bb86bb2a7f539
--- /dev/null
+++ b/test/multi_node_test_2.py
@@ -0,0 +1,72 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os
import pytest

import ray
import ray.services as services
from ray.test.cluster_utils import Cluster

logger = logging.getLogger(__name__)


@pytest.fixture
def start_connected_cluster():
    # Start the Ray processes.
    g = Cluster(initialize_head=True, connect=True)
    yield g
    # The code after the yield runs as teardown code.
    ray.shutdown()
    g.shutdown()


@pytest.mark.skipif(
    os.environ.get("RAY_USE_XRAY") != "1",
    reason="This test only works with xray.")
def test_cluster():
    """Basic test for adding and removing nodes in a cluster."""
    g = Cluster(initialize_head=False)
    node = g.add_node()
    node2 = g.add_node()
    assert node.all_processes_alive()
    assert node2.all_processes_alive()
    g.remove_node(node2)
    g.remove_node(node)
    assert not any(n.any_processes_alive() for n in g.list_all_nodes())


@pytest.mark.skipif(
    os.environ.get("RAY_USE_XRAY") != "1",
    reason="This test only works with xray.")
def test_wait_for_nodes(start_connected_cluster):
    """Unit test for `Cluster.wait_for_nodes`.

    Adds 4 workers and waits, removes the 4 workers and waits, then adds
    1 worker and waits, and finally removes that worker and waits.
    """
    cluster = start_connected_cluster
    workers = [cluster.add_node() for _ in range(4)]
    cluster.wait_for_nodes()
    for w in workers:
        cluster.remove_node(w)
    cluster.wait_for_nodes()
    worker2 = cluster.add_node()
    cluster.wait_for_nodes()
    cluster.remove_node(worker2)
    cluster.wait_for_nodes()


@pytest.mark.skipif(
    os.environ.get("RAY_USE_XRAY") != "1",
    reason="This test only works with xray.")
def test_worker_plasma_store_failure(start_connected_cluster):
    cluster = start_connected_cluster
    worker = cluster.add_node()
    cluster.wait_for_nodes()
    # Kill the log monitor first; it does not exit on its own when the
    # plasma store dies.
    worker.kill_log_monitor()
    worker.kill_plasma_store()
    # The raylet should exit once its plasma store is gone.
    worker.process_dict[services.PROCESS_TYPE_RAYLET][0].wait()
    assert not worker.any_processes_alive(), worker.live_processes()
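
The helpers above also make it straightforward to exercise other single-process failures. Below is a minimal usage sketch, not part of the patch itself: it assumes the same `RAY_USE_XRAY` guard and `start_connected_cluster` fixture as the tests above, and the test name and the expectation that the other node processes outlive a killed raylet are assumptions for illustration.

# Illustrative sketch only: uses the Node helpers above; assumes the plasma
# store and log monitor keep running when only the raylet is killed.
@pytest.mark.skipif(
    os.environ.get("RAY_USE_XRAY") != "1",
    reason="This test only works with xray.")
def test_worker_raylet_failure(start_connected_cluster):
    cluster = start_connected_cluster
    worker = cluster.add_node()
    cluster.wait_for_nodes()
    # Kill only the raylet on the worker node.
    worker.kill_raylet()
    # The raylet is now dead, so not all processes are alive ...
    assert not worker.all_processes_alive()
    # ... but the node's remaining processes should still be running.
    assert worker.any_processes_alive(), worker.dead_processes()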