forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Job submission] Basic job submission structure (ray-project#15103)
- Loading branch information
Showing
12 changed files
with
711 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,315 @@ | ||
import asyncio | ||
import json | ||
import logging | ||
import os.path | ||
import itertools | ||
import subprocess | ||
import sys | ||
import secrets | ||
import uuid | ||
import traceback | ||
from abc import abstractmethod | ||
from typing import Union, Any | ||
|
||
import ray.new_dashboard.utils as dashboard_utils | ||
from ray.new_dashboard.utils import create_task | ||
from ray.new_dashboard.modules.job import job_consts | ||
from ray.new_dashboard.modules.job.job_description import JobDescription | ||
from ray.core.generated import job_agent_pb2 | ||
from ray.core.generated import job_agent_pb2_grpc | ||
from ray.core.generated import agent_manager_pb2 | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class JobInfo(JobDescription): | ||
# TODO(fyrestone): We should use job id instead of unique id. | ||
unique_id: str | ||
# The temp directory. | ||
temp_dir: str | ||
# The log directory. | ||
log_dir: str | ||
# The driver process instance. | ||
driver: Union[None, asyncio.subprocess.Process] | ||
|
||
def __init__(self, **data: Any): | ||
super().__init__(**data) | ||
# Support json values for env. | ||
self.env = { | ||
k: v if isinstance(v, str) else json.dumps(v) | ||
for k, v in self.env.items() | ||
} | ||
|
||
|
||
class JobProcessor: | ||
"""Wraps the job info and provides common utils to download packages, | ||
start drivers, etc. | ||
Args: | ||
job_info (JobInfo): The job info. | ||
""" | ||
_cmd_index_gen = itertools.count(1) | ||
|
||
def __init__(self, job_info): | ||
assert isinstance(job_info, JobInfo) | ||
self._job_info = job_info | ||
|
||
async def _download_package(self, http_session, url, filename): | ||
unique_id = self._job_info.unique_id | ||
cmd_index = next(self._cmd_index_gen) | ||
logger.info("[%s] Start download[%s] %s to %s", unique_id, cmd_index, | ||
url, filename) | ||
async with http_session.get(url, ssl=False) as response: | ||
with open(filename, "wb") as f: | ||
while True: | ||
chunk = await response.content.read( | ||
job_consts.DOWNLOAD_BUFFER_SIZE) | ||
if not chunk: | ||
break | ||
f.write(chunk) | ||
logger.info("[%s] Finished download[%s] %s to %s", unique_id, | ||
cmd_index, url, filename) | ||
|
||
async def _unpack_package(self, filename, path): | ||
code = f"import shutil; " \ | ||
f"shutil.unpack_archive({repr(filename)}, {repr(path)})" | ||
unzip_cmd = [self._get_current_python(), "-c", code] | ||
await self._check_output_cmd(unzip_cmd) | ||
|
||
async def _check_output_cmd(self, cmd): | ||
proc = await asyncio.create_subprocess_exec( | ||
*cmd, | ||
stdout=asyncio.subprocess.PIPE, | ||
stderr=asyncio.subprocess.PIPE) | ||
unique_id = self._job_info.unique_id | ||
cmd_index = next(self._cmd_index_gen) | ||
proc.cmd_index = cmd_index | ||
logger.info("[%s] Run cmd[%s] %s", unique_id, cmd_index, repr(cmd)) | ||
stdout, stderr = await proc.communicate() | ||
stdout = stdout.decode("utf-8") | ||
logger.info("[%s] Output of cmd[%s]: %s", unique_id, cmd_index, stdout) | ||
if proc.returncode != 0: | ||
stderr = stderr.decode("utf-8") | ||
logger.error("[%s] Error of cmd[%s]: %s", unique_id, cmd_index, | ||
stderr) | ||
raise subprocess.CalledProcessError( | ||
proc.returncode, cmd, output=stdout, stderr=stderr) | ||
return stdout | ||
|
||
async def _start_driver(self, cmd, stdout, stderr, env): | ||
unique_id = self._job_info.unique_id | ||
job_package_dir = job_consts.JOB_UNPACK_DIR.format( | ||
temp_dir=self._job_info.temp_dir, unique_id=unique_id) | ||
cmd_str = subprocess.list2cmdline(cmd) | ||
proc = await asyncio.create_subprocess_exec( | ||
*cmd, | ||
stdout=stdout, | ||
stderr=stderr, | ||
env={ | ||
**os.environ, | ||
**env, | ||
}, | ||
cwd=job_package_dir, | ||
) | ||
logger.info("[%s] Start driver cmd %s with pid %s", unique_id, | ||
repr(cmd_str), proc.pid) | ||
return proc | ||
|
||
@staticmethod | ||
def _get_current_python(): | ||
return sys.executable | ||
|
||
@staticmethod | ||
def _new_log_files(log_dir, filename): | ||
if log_dir is None: | ||
return None, None | ||
stdout = open( | ||
os.path.join(log_dir, filename + ".out"), "a", buffering=1) | ||
stderr = open( | ||
os.path.join(log_dir, filename + ".err"), "a", buffering=1) | ||
return stdout, stderr | ||
|
||
@abstractmethod | ||
async def run(self): | ||
pass | ||
|
||
|
||
class DownloadPackage(JobProcessor): | ||
""" Download the job package. | ||
Args: | ||
job_info (JobInfo): The job info. | ||
http_session (aiohttp.ClientSession): The client session. | ||
""" | ||
|
||
def __init__(self, job_info, http_session): | ||
super().__init__(job_info) | ||
self._http_session = http_session | ||
|
||
async def run(self): | ||
temp_dir = self._job_info.temp_dir | ||
unique_id = self._job_info.unique_id | ||
filename = job_consts.DOWNLOAD_PACKAGE_FILE.format( | ||
temp_dir=temp_dir, unique_id=unique_id) | ||
unpack_dir = job_consts.JOB_UNPACK_DIR.format( | ||
temp_dir=temp_dir, unique_id=unique_id) | ||
url = self._job_info.runtime_env.working_dir | ||
await self._download_package(self._http_session, url, filename) | ||
await self._unpack_package(filename, unpack_dir) | ||
|
||
|
||
class StartPythonDriver(JobProcessor): | ||
""" Start the driver for Python job. | ||
Args: | ||
job_info (JobInfo): The job info. | ||
redis_address (tuple): The (ip, port) of redis. | ||
redis_password (str): The password of redis. | ||
""" | ||
|
||
_template = """import sys | ||
sys.path.append({import_path}) | ||
import ray | ||
from ray._private.utils import hex_to_binary | ||
ray.init(ignore_reinit_error=True, | ||
address={redis_address}, | ||
_redis_password={redis_password}, | ||
job_config=ray.job_config.JobConfig({job_config_args}), | ||
) | ||
import {driver_entry} | ||
{driver_entry}.main({driver_args}) | ||
# If the driver exits normally, we invoke Ray.shutdown() again | ||
# here, in case the user code forgot to invoke it. | ||
ray.shutdown() | ||
""" | ||
|
||
def __init__(self, job_info, redis_address, redis_password): | ||
super().__init__(job_info) | ||
self._redis_address = redis_address | ||
self._redis_password = redis_password | ||
|
||
def _gen_driver_code(self): | ||
temp_dir = self._job_info.temp_dir | ||
unique_id = self._job_info.unique_id | ||
job_package_dir = job_consts.JOB_UNPACK_DIR.format( | ||
temp_dir=temp_dir, unique_id=unique_id) | ||
driver_entry_file = job_consts.JOB_DRIVER_ENTRY_FILE.format( | ||
temp_dir=temp_dir, unique_id=unique_id, uuid=uuid.uuid4()) | ||
ip, port = self._redis_address | ||
|
||
# Per job config | ||
job_config_items = { | ||
"worker_env": self._job_info.env, | ||
"code_search_path": [job_package_dir], | ||
} | ||
|
||
job_config_args = ", ".join(f"{key}={repr(value)}" | ||
for key, value in job_config_items.items() | ||
if value is not None) | ||
driver_args = ", ".join([repr(x) for x in self._job_info.driver_args]) | ||
driver_code = self._template.format( | ||
job_config_args=job_config_args, | ||
import_path=repr(job_package_dir), | ||
redis_address=repr(ip + ":" + str(port)), | ||
redis_password=repr(self._redis_password), | ||
driver_entry=self._job_info.driver_entry, | ||
driver_args=driver_args) | ||
with open(driver_entry_file, "w") as fp: | ||
fp.write(driver_code) | ||
return driver_entry_file | ||
|
||
async def run(self): | ||
python = self._get_current_python() | ||
driver_file = self._gen_driver_code() | ||
driver_cmd = [python, "-u", driver_file] | ||
stdout_file, stderr_file = self._new_log_files( | ||
self._job_info.log_dir, f"driver-{self._job_info.unique_id}") | ||
return await self._start_driver(driver_cmd, stdout_file, stderr_file, | ||
self._job_info.env) | ||
|
||
|
||
class JobAgent(dashboard_utils.DashboardAgentModule, | ||
job_agent_pb2_grpc.JobAgentServiceServicer): | ||
""" The JobAgentService defined in job_agent.proto for initializing / | ||
cleaning job environments. | ||
""" | ||
|
||
async def InitializeJobEnv(self, request, context): | ||
# TODO(fyrestone): Handle duplicated InitializeJobEnv requests | ||
# when initializing job environment. | ||
# TODO(fyrestone): Support reinitialize job environment. | ||
|
||
# TODO(fyrestone): Use job id instead of unique id. | ||
unique_id = secrets.token_hex(6) | ||
|
||
# Parse the job description from the request. | ||
try: | ||
job_description_data = json.loads(request.job_description) | ||
job_info = JobInfo( | ||
unique_id=unique_id, | ||
temp_dir=self._dashboard_agent.temp_dir, | ||
log_dir=self._dashboard_agent.log_dir, | ||
**job_description_data) | ||
except json.JSONDecodeError as ex: | ||
error_message = str(ex) | ||
error_message += f", job_payload:\n{request.job_description}" | ||
logger.error("[%s] Initialize job environment failed, %s.", | ||
unique_id, error_message) | ||
return job_agent_pb2.InitializeJobEnvReply( | ||
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED, | ||
error_message=error_message) | ||
except Exception as ex: | ||
logger.exception(ex) | ||
return job_agent_pb2.InitializeJobEnvReply( | ||
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED, | ||
error_message=traceback.format_exc()) | ||
|
||
async def _initialize_job_env(): | ||
os.makedirs( | ||
job_consts.JOB_DIR.format( | ||
temp_dir=job_info.temp_dir, unique_id=unique_id), | ||
exist_ok=True) | ||
# Download the job package. | ||
await DownloadPackage(job_info, | ||
self._dashboard_agent.http_session).run() | ||
# Start the driver. | ||
logger.info("[%s] Starting driver.", unique_id) | ||
language = job_info.language | ||
if language == job_consts.PYTHON: | ||
driver = await StartPythonDriver( | ||
job_info, self._dashboard_agent.redis_address, | ||
self._dashboard_agent.redis_password).run() | ||
else: | ||
raise Exception(f"Unsupported language type: {language}") | ||
job_info.driver = driver | ||
|
||
initialize_task = create_task(_initialize_job_env()) | ||
|
||
try: | ||
await initialize_task | ||
except asyncio.CancelledError: | ||
logger.error("[%s] Initialize job environment has been cancelled.", | ||
unique_id) | ||
return job_agent_pb2.InitializeJobEnvReply( | ||
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED, | ||
error_message="InitializeJobEnv has been cancelled, " | ||
"did you call CleanJobEnv?") | ||
except Exception as ex: | ||
logger.exception(ex) | ||
return job_agent_pb2.InitializeJobEnvReply( | ||
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED, | ||
error_message=traceback.format_exc()) | ||
|
||
driver_pid = 0 | ||
if job_info.driver: | ||
driver_pid = job_info.driver.pid | ||
|
||
logger.info( | ||
"[%s] Job environment initialized, " | ||
"the driver (pid=%s) started.", unique_id, driver_pid) | ||
return job_agent_pb2.InitializeJobEnvReply( | ||
status=agent_manager_pb2.AGENT_RPC_STATUS_OK, | ||
driver_pid=driver_pid) | ||
|
||
async def run(self, server): | ||
job_agent_pb2_grpc.add_JobAgentServiceServicer_to_server(self, server) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,18 @@ | ||
import os | ||
from ray.core.generated import common_pb2 | ||
|
||
# Job agent consts | ||
# TODO(fyrestone): We should use job id instead of unique_id. | ||
JOB_DIR = "{temp_dir}/job/{unique_id}/" | ||
JOB_UNPACK_DIR = os.path.join(JOB_DIR, "package") | ||
JOB_DRIVER_ENTRY_FILE = os.path.join(JOB_DIR, "driver-{uuid}.py") | ||
# Downloader constants | ||
DOWNLOAD_BUFFER_SIZE = 10 * 1024 * 1024 # 10MB | ||
DOWNLOAD_PACKAGE_FILE = os.path.join(JOB_DIR, "package.zip") | ||
# Redis key | ||
JOB_CHANNEL = "JOB" | ||
RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS = 2 | ||
# Languages | ||
PYTHON = common_pb2.Language.Name(common_pb2.Language.PYTHON) | ||
JAVA = common_pb2.Language.Name(common_pb2.Language.JAVA) | ||
CPP = common_pb2.Language.Name(common_pb2.Language.CPP) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from enum import Enum | ||
from typing import Dict, Any | ||
from pydantic import BaseModel as PydanticBaseModel, Extra | ||
from ray.core.generated import common_pb2 | ||
|
||
|
||
class BaseModel(PydanticBaseModel): | ||
class Config: | ||
arbitrary_types_allowed = True | ||
extra = Extra.forbid | ||
|
||
|
||
class Language(str, Enum): | ||
PYTHON = common_pb2.Language.Name(common_pb2.PYTHON) | ||
JAVA = common_pb2.Language.Name(common_pb2.JAVA) | ||
CPP = common_pb2.Language.Name(common_pb2.CPP) | ||
|
||
|
||
class RuntimeEnv(BaseModel): | ||
# The url to download the job package archive. The archive format is | ||
# one of “zip”, “tar”, “gztar”, “bztar”, or “xztar”. Please refer to | ||
# https://docs.python.org/3/library/shutil.html#shutil.unpack_archive | ||
working_dir: str | ||
|
||
|
||
class JobDescription(BaseModel): | ||
# The job driver language, this field determines how to start the | ||
# driver. The value is one of the names of enum Language defined in | ||
# common.proto, e.g. PYTHON | ||
language: Language | ||
# The runtime_env (RuntimeEnvDict) for the job config. | ||
runtime_env: RuntimeEnv | ||
# The entry to start the driver. | ||
# PYTHON: | ||
# - The basename of driver filename without extension in the job | ||
# package archive. | ||
# JAVA: | ||
# - The driver class full name in the job package archive. | ||
driver_entry: str | ||
# The driver arguments in list. | ||
# PYTHON: | ||
# - The arguments to pass to the main() function in driver entry. | ||
# e.g. [1, False, 3.14, "abc"] | ||
# JAVA: | ||
# - The arguments to pass to the driver command line. | ||
# e.g. ["-custom-arg", "abc"] | ||
driver_args: list = [] | ||
# The environment vars to pass to job config, type of keys should be str. | ||
env: Dict[str, Any] = {} |
Oops, something went wrong.