# This workload tests RLlib's ability to recover from failing worker nodes.
import time
import unittest

import ray
from ray._private.test_utils import get_other_nodes
from ray.cluster_utils import Cluster
from ray.experimental.state.api import list_actors
from ray.rllib.algorithms.ppo import PPO, PPOConfig

num_redis_shards = 5
redis_max_memory = 10**8
object_store_memory = 10**8
num_nodes = 3

assert (
    num_nodes * object_store_memory + num_redis_shards * redis_max_memory
    < ray._private.utils.get_system_memory() / 2
), (
    "Make sure there is enough memory on this machine to run this "
    "workload. We divide the system memory by 2 to provide a buffer."
)
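
# Cluster sizing note: 3 nodes x 2 CPUs = 6 CPUs, which matches the 6 remote
# rollout workers requested in the test below (RLlib reserves 1 CPU per rollout
# worker by default), so each node should end up hosting roughly 2 workers.
# That is why removing a single non-head node is expected to drop 2 of them
# (6 -> 4 healthy workers, as asserted in the test).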


class NodeFailureTests(unittest.TestCase):
    def setUp(self):
        # Simulate a cluster on one machine.
        self.cluster = Cluster()

        for i in range(num_nodes):
            self.cluster.add_node(
                redis_port=6379 if i == 0 else None,
                num_redis_shards=num_redis_shards if i == 0 else None,
                num_cpus=2,
                num_gpus=0,
                object_store_memory=object_store_memory,
                redis_max_memory=redis_max_memory,
                dashboard_host="0.0.0.0",
            )
        self.cluster.wait_for_nodes()
        ray.init(address=self.cluster.address)

    def tearDown(self):
        ray.shutdown()
        self.cluster.shutdown()
    def test_continue_training_on_failure(self):
        # We tolerate failing workers and continue training with the
        # remaining ones; failed workers are recreated when possible.
        config = (
            PPOConfig()
            .rollouts(
                num_rollout_workers=6,
                recreate_failed_workers=True,
                validate_workers_after_construction=True,
            )
            .training(
                train_batch_size=300,
            )
        )
        ppo = PPO(config=config, env="CartPole-v1")

        # One step with all nodes up, enough to satisfy resource requirements.
        ppo.step()
        self.assertEqual(ppo.workers.num_healthy_remote_workers(), 6)
        self.assertEqual(ppo.workers.num_remote_workers(), 6)

        # Remove the first non-head node.
        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
        self.cluster.remove_node(node_to_kill)

        # step() should continue with 4 rollout workers.
        ppo.step()
        self.assertEqual(ppo.workers.num_healthy_remote_workers(), 4)
        self.assertEqual(ppo.workers.num_remote_workers(), 6)

        # The node comes back immediately.
        self.cluster.add_node(
            redis_port=None,
            num_redis_shards=None,
            num_cpus=2,
            num_gpus=0,
            object_store_memory=object_store_memory,
            redis_max_memory=redis_max_memory,
            dashboard_host="0.0.0.0",
        )

        # Now, let's wait for Ray to restart all the RolloutWorker actors.
        while True:
            states = [
                a["state"] == "ALIVE"
                for a in list_actors()
                if a["class_name"] == "RolloutWorker"
            ]
            if all(states):
                break
            # Otherwise, wait a bit.
            time.sleep(1)

        # This step should continue with 4 workers, but by the end of weight
        # syncing, the 2 recovered rollout workers should be back.
        ppo.step()

        # Workers should be back up, everything back to normal.
        self.assertEqual(ppo.workers.num_healthy_remote_workers(), 6)
        self.assertEqual(ppo.workers.num_remote_workers(), 6)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))