Cluster Utilities for Fault Tolerance Tests (ray-project#3008)
richardliaw authored and robertnishihara committed Oct 21, 2018
1 parent a4db5bb commit 40c4148
Showing 4 changed files with 275 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
@@ -146,6 +146,7 @@ matrix:
- python -m pytest -v test/stress_tests.py
- pytest test/component_failures_test.py
- python test/multi_node_test.py
- python -m pytest -v test/multi_node_test_2.py
- python -m pytest -v test/recursion_test.py
- pytest test/monitor_test.py
- python -m pytest -v test/cython_test.py
@@ -223,6 +224,7 @@ script:
- python -m pytest -v test/stress_tests.py
- python -m pytest -v test/component_failures_test.py
- python test/multi_node_test.py
- python -m pytest -v test/multi_node_test_2.py
- python -m pytest -v test/recursion_test.py
- python -m pytest -v test/monitor_test.py
- python -m pytest -v test/cython_test.py
201 changes: 201 additions & 0 deletions python/ray/test/cluster_utils.py
@@ -0,0 +1,201 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import time

import ray
import ray.services as services

logger = logging.getLogger(__name__)


class Cluster(object):
    def __init__(self,
                 initialize_head=False,
                 connect=False,
                 head_node_args=None):
        """Initializes the cluster.

        Args:
            initialize_head (bool): Automatically start a Ray cluster
                by initializing the head node. Defaults to False.
            connect (bool): If `initialize_head=True` and `connect=True`,
                ray.init will be called with the redis address of this
                cluster passed in.
            head_node_args (kwargs): Arguments to be passed into
                `start_ray_head` via `self.add_node`.
        """
        self.head_node = None
        self.worker_nodes = {}
        self.redis_address = None
        if not initialize_head and connect:
            raise RuntimeError("Cannot connect to uninitialized cluster.")

        if initialize_head:
            head_node_args = head_node_args or {}
            self.add_node(**head_node_args)
            if connect:
                ray.init(redis_address=self.redis_address)

    def add_node(self, **override_kwargs):
        """Adds a node to the local Ray Cluster.

        All nodes are by default started with the following settings:
            cleanup=True,
            use_raylet=True,
            resources={"CPU": 1},
            object_store_memory=100 * (2**20)  # 100 MB

        Args:
            override_kwargs: Keyword arguments used in `start_ray_head`
                and `start_ray_node`. Overrides defaults.

        Returns:
            Node object of the added Ray node.
        """
        node_kwargs = dict(
            cleanup=True,
            use_raylet=True,
            resources={"CPU": 1},
            object_store_memory=100 * (2**20)  # 100 MB
        )
        node_kwargs.update(override_kwargs)

        if self.head_node is None:
            address_info = services.start_ray_head(
                node_ip_address=services.get_node_ip_address(),
                include_webui=False,
                **node_kwargs)
            self.redis_address = address_info["redis_address"]
            # TODO(rliaw): Find a more stable way than modifying global state.
            process_dict_copy = services.all_processes.copy()
            for key in services.all_processes:
                services.all_processes[key] = []
            node = Node(process_dict_copy)
            self.head_node = node
        else:
            address_info = services.start_ray_node(
                services.get_node_ip_address(), self.redis_address,
                **node_kwargs)
            # TODO(rliaw): Find a more stable way than modifying global state.
            process_dict_copy = services.all_processes.copy()
            for key in services.all_processes:
                services.all_processes[key] = []
            node = Node(process_dict_copy)
            self.worker_nodes[node] = address_info
        logger.info("Starting Node with raylet socket {}".format(
            address_info["raylet_socket_names"]))

        return node

    def remove_node(self, node):
        """Kills all processes associated with worker node.

        Args:
            node (Node): Worker node of which all associated processes
                will be removed.
        """
        if self.head_node == node:
            self.head_node.kill_all_processes()
            self.head_node = None
            # TODO(rliaw): Do we need to kill all worker processes?
        else:
            node.kill_all_processes()
            self.worker_nodes.pop(node)

        assert not node.any_processes_alive(), (
            "There are zombie processes left over after killing.")

    def wait_for_nodes(self, retries=20):
        """Waits for all nodes to be registered with global state.

        Args:
            retries (int): Number of times to retry checking client table.
        """
        for i in range(retries):
            if not ray.is_initialized() or not self._check_registered_nodes():
                time.sleep(0.3)
            else:
                break

    def _check_registered_nodes(self):
        registered = len([
            client for client in ray.global_state.client_table()
            if client["IsInsertion"]
        ])
        expected = len(self.list_all_nodes())
        if registered == expected:
            logger.info("All nodes registered as expected.")
        else:
            logger.info("Currently registering {} but expecting {}".format(
                registered, expected))
        return registered == expected

    def list_all_nodes(self):
        """Lists all nodes.

        TODO(rliaw): What is the desired behavior if a head node
            dies before worker nodes die?

        Returns:
            List of all nodes, including the head node.
        """
        nodes = list(self.worker_nodes)
        if self.head_node:
            nodes = [self.head_node] + nodes
        return nodes

    def shutdown(self):
        # We create a list here as a copy because `remove_node`
        # modifies `self.worker_nodes`.
        all_nodes = list(self.worker_nodes)
        for node in all_nodes:
            self.remove_node(node)
        self.remove_node(self.head_node)


class Node(object):
"""Abstraction for a Ray node."""

def __init__(self, process_dict):
# TODO(rliaw): Is there a unique identifier for a node?
self.process_dict = process_dict

def kill_plasma_store(self):
self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].kill()
self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].wait()

def kill_raylet(self):
self.process_dict[services.PROCESS_TYPE_RAYLET][0].kill()
self.process_dict[services.PROCESS_TYPE_RAYLET][0].wait()

def kill_log_monitor(self):
self.process_dict["log_monitor"][0].kill()
self.process_dict["log_monitor"][0].wait()

def kill_all_processes(self):
for process_name, process_list in self.process_dict.items():
logger.info("Killing all {}(s)".format(process_name))
for process in process_list:
process.kill()

        for process_name, process_list in self.process_dict.items():
            logger.info("Waiting for all {}(s) to exit".format(process_name))
            for process in process_list:
                process.wait()

    def live_processes(self):
        return [(p_name, proc) for p_name, p_list in self.process_dict.items()
                for proc in p_list if proc.poll() is None]

    def dead_processes(self):
        return [(p_name, proc) for p_name, p_list in self.process_dict.items()
                for proc in p_list if proc.poll() is not None]

    def any_processes_alive(self):
        return any(self.live_processes())

    def all_processes_alive(self):
        return not any(self.dead_processes())
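
A minimal usage sketch of the new Cluster utility (illustration only, not
part of this diff; it assumes the defaults shown above and an xray-enabled
local build of Ray):

import ray
from ray.test.cluster_utils import Cluster

# Start a head node and connect this driver to it.
cluster = Cluster(initialize_head=True, connect=True)

# Add a worker node with the default settings (1 CPU, 100 MB object store)
# and block until both nodes appear in the global client table.
worker = cluster.add_node()
cluster.wait_for_nodes()

# Simulate a node failure, then tear everything down in the same order the
# test fixture below uses: disconnect the driver, then stop the nodes.
cluster.remove_node(worker)
ray.shutdown()
cluster.shutdown()
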
4 changes: 0 additions & 4 deletions python/ray/worker.py
@@ -1837,10 +1837,6 @@ def shutdown(worker=global_worker):
        worker.plasma_client.disconnect()

    if worker.mode == SCRIPT_MODE:
        # If this is a driver, push the finish time to Redis and clean up any
        # other services that were started with the driver.
        worker.redis_client.hmset(b"Drivers:" + worker.worker_id,
                                  {"end_time": time.time()})
        services.cleanup()
    else:
        # If this is not a driver, make sure there are no orphan processes,
72 changes: 72 additions & 0 deletions test/multi_node_test_2.py
@@ -0,0 +1,72 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os
import pytest

import ray
import ray.services as services
from ray.test.cluster_utils import Cluster

logger = logging.getLogger(__name__)


@pytest.fixture
def start_connected_cluster():
    # Start the Ray processes.
    g = Cluster(initialize_head=True, connect=True)
    yield g
    # The code after the yield will run as teardown code.
    ray.shutdown()
    g.shutdown()


@pytest.mark.skipif(
    os.environ.get("RAY_USE_XRAY") != "1",
    reason="This test only works with xray.")
def test_cluster():
    """Basic test for adding and removing nodes in cluster."""
    g = Cluster(initialize_head=False)
    node = g.add_node()
    node2 = g.add_node()
    assert node.all_processes_alive()
    assert node2.all_processes_alive()
    g.remove_node(node2)
    g.remove_node(node)
    assert not any(node.any_processes_alive() for node in g.list_all_nodes())


@pytest.mark.skipif(
    os.environ.get("RAY_USE_XRAY") != "1",
    reason="This test only works with xray.")
def test_wait_for_nodes(start_connected_cluster):
    """Unit test for `Cluster.wait_for_nodes`.

    Adds 4 workers, waits, then removes 4 workers, waits,
    then adds 1 worker, waits, and removes 1 worker, waits.
    """
    cluster = start_connected_cluster
    workers = [cluster.add_node() for i in range(4)]
    cluster.wait_for_nodes()
    [cluster.remove_node(w) for w in workers]
    cluster.wait_for_nodes()
    worker2 = cluster.add_node()
    cluster.wait_for_nodes()
    cluster.remove_node(worker2)
    cluster.wait_for_nodes()


@pytest.mark.skipif(
    os.environ.get("RAY_USE_XRAY") != "1",
    reason="This test only works with xray.")
def test_worker_plasma_store_failure(start_connected_cluster):
    cluster = start_connected_cluster
    worker = cluster.add_node()
    cluster.wait_for_nodes()
    # The log monitor does not exit on its own, so kill it explicitly.
    worker.kill_log_monitor()
    worker.kill_plasma_store()
    worker.process_dict[services.PROCESS_TYPE_RAYLET][0].wait()
    assert not worker.any_processes_alive(), worker.live_processes()
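
For illustration only (not part of this commit): a fixture variant that
exercises the head_node_args path, which the tests above leave at its
defaults. The fixture name and the CPU count are hypothetical; the sketch
relies on the same imports as the file above.

@pytest.fixture
def start_connected_2_cpu_cluster():
    # Hypothetical sketch: same pattern as start_connected_cluster, but the
    # head node overrides the default resources passed to start_ray_head.
    g = Cluster(
        initialize_head=True,
        connect=True,
        head_node_args={"resources": {"CPU": 2}})
    yield g
    ray.shutdown()
    g.shutdown()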
