# job_manager.py (forked from ray-project/ray)
import os
import time
from typing import Dict, Tuple, TYPE_CHECKING

if TYPE_CHECKING:
    from ray.job_submission import JobSubmissionClient, JobStatus  # noqa: F401

from ray_release.logger import logger
from ray_release.util import ANYSCALE_HOST, exponential_backoff_retry
from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.exception import CommandTimeout


class JobManager:
    def __init__(self, cluster_manager: ClusterManager):
        # Maps an internal command id to the Ray job id and to the wall-clock
        # submission time, so several submitted commands can be tracked at once.
        self.job_id_pool: Dict[int, str] = dict()
        self.start_time: Dict[int, float] = dict()
        self.counter = 0
        self.cluster_manager = cluster_manager
        self.job_client = None
        self.last_job_id = None

    def _get_job_client(self) -> "JobSubmissionClient":
        from ray.job_submission import JobSubmissionClient  # noqa: F811

        # Lazily create and cache a Jobs API client pointed at the cluster.
        if not self.job_client:
            self.job_client = JobSubmissionClient(
                self.cluster_manager.get_cluster_address()
            )
        return self.job_client
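    # As a rough sketch, connecting to the same Jobs API by hand looks like the
    # snippet below; the address is a placeholder assumption, not a value used
    # by this module.
    #
    #     from ray.job_submission import JobSubmissionClient
    #     client = JobSubmissionClient("http://127.0.0.1:8265")  # example address
    #     print(client.list_jobs())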

    def _run_job(self, cmd_to_run, env_vars) -> int:
        self.counter += 1
        command_id = self.counter

        env = os.environ.copy()
        env["RAY_ADDRESS"] = self.cluster_manager.get_cluster_address()
        env.setdefault("ANYSCALE_HOST", ANYSCALE_HOST)

        # The env vars reach the job as a KEY=VALUE prefix on the entrypoint
        # command; the local `env` dict above is not forwarded to `submit_job`.
        full_cmd = " ".join(f"{k}={v}" for k, v in env_vars.items()) + " " + cmd_to_run
        logger.info(f"Executing {cmd_to_run} with {env_vars} via ray job submit")

        job_client = self._get_job_client()
        job_id = job_client.submit_job(
            # Entrypoint shell command to execute
            entrypoint=full_cmd,
        )
        self.last_job_id = job_id
        self.job_id_pool[command_id] = job_id
        self.start_time[command_id] = time.time()
        return command_id
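    # Rough CLI equivalent of the submission above, assuming RAY_ADDRESS points
    # at the same cluster (the address, variable, and command are placeholders):
    #
    #     RAY_ADDRESS="http://127.0.0.1:8265" ray job submit -- MY_VAR=1 python workload.py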

    def _get_job_status_with_retry(self, command_id):
        job_client = self._get_job_client()
        return exponential_backoff_retry(
            lambda: job_client.get_job_status(self.job_id_pool[command_id]),
            retry_exceptions=Exception,
            initial_retry_delay_s=1,
            max_retries=3,
        )
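    # A minimal sketch of the assumed behavior of `exponential_backoff_retry`
    # (the real helper lives in ray_release.util and may differ in detail):
    # retry the callable on the given exceptions, doubling the delay between
    # attempts.
    #
    #     def exponential_backoff_retry(f, retry_exceptions, initial_retry_delay_s, max_retries):
    #         delay = initial_retry_delay_s
    #         for attempt in range(max_retries):
    #             try:
    #                 return f()
    #             except retry_exceptions:
    #                 if attempt == max_retries - 1:
    #                     raise
    #                 time.sleep(delay)
    #                 delay *= 2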

    def _wait_job(self, command_id: int, timeout: int):
        from ray.job_submission import JobStatus  # noqa: F811

        start_time = time.monotonic()
        timeout_at = start_time + timeout
        next_status = start_time + 30

        while True:
            now = time.monotonic()
            if now >= timeout_at:
                raise CommandTimeout(
                    f"Cluster command timed out after {timeout} seconds."
                )
            if now >= next_status:
                logger.info(
                    "... command still running ..."
                    f"({int(now - start_time)} seconds) ..."
                )
                next_status += 30
            status = self._get_job_status_with_retry(command_id)
            if status in {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}:
                break
            time.sleep(1)

        status = self._get_job_status_with_retry(command_id)
        # TODO(sang): Propagate JobInfo.error_type
        if status == JobStatus.SUCCEEDED:
            retcode = 0
        else:
            retcode = -1
        duration = time.time() - self.start_time[command_id]
        return retcode, duration
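    # For example, a command that ends in JobStatus.FAILED after roughly 45
    # seconds yields (-1, ~45.0); only JobStatus.SUCCEEDED maps to return code 0.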

    def run_and_wait(
        self, cmd_to_run, env_vars, timeout: int = 120
    ) -> Tuple[int, float]:
        cid = self._run_job(cmd_to_run, env_vars)
        return self._wait_job(cid, timeout)

    def get_last_logs(self):
        job_client = self._get_job_client()
        return job_client.get_job_logs(self.last_job_id)
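

# A minimal usage sketch, assuming a concrete ClusterManager is already set up
# and pointing at a running cluster; the command, env vars, and timeout below
# are placeholder assumptions.
#
#     cluster_manager = ...  # some ClusterManager subclass instance
#     job_manager = JobManager(cluster_manager)
#     retcode, duration = job_manager.run_and_wait(
#         "python workload.py",
#         env_vars={"SMOKE_TEST": "1"},
#         timeout=600,
#     )
#     if retcode != 0:
#         logger.error(job_manager.get_last_logs())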