Skip to content

Commit

Permalink
[Job submission] Basic job submission structure (ray-project#15103)
Browse files Browse the repository at this point in the history
  • Loading branch information
fyrestone authored May 12, 2021
1 parent fcf56fb commit 56c3094
Show file tree
Hide file tree
Showing 12 changed files with 711 additions and 7 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,7 @@ filegroup(
"//src/ray/protobuf:core_worker_py_proto",
"//src/ray/protobuf:gcs_py_proto",
"//src/ray/protobuf:gcs_service_py_proto",
"//src/ray/protobuf:job_agent_py_proto",
"//src/ray/protobuf:node_manager_py_proto",
"//src/ray/protobuf:ray_client_py_proto",
"//src/ray/protobuf:reporter_py_proto",
Expand Down
315 changes: 315 additions & 0 deletions dashboard/modules/job/job_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
import asyncio
import json
import logging
import os.path
import itertools
import subprocess
import sys
import secrets
import uuid
import traceback
from abc import abstractmethod
from typing import Union, Any

import ray.new_dashboard.utils as dashboard_utils
from ray.new_dashboard.utils import create_task
from ray.new_dashboard.modules.job import job_consts
from ray.new_dashboard.modules.job.job_description import JobDescription
from ray.core.generated import job_agent_pb2
from ray.core.generated import job_agent_pb2_grpc
from ray.core.generated import agent_manager_pb2

logger = logging.getLogger(__name__)


class JobInfo(JobDescription):
# TODO(fyrestone): We should use job id instead of unique id.
unique_id: str
# The temp directory.
temp_dir: str
# The log directory.
log_dir: str
# The driver process instance.
driver: Union[None, asyncio.subprocess.Process]

def __init__(self, **data: Any):
super().__init__(**data)
# Support json values for env.
self.env = {
k: v if isinstance(v, str) else json.dumps(v)
for k, v in self.env.items()
}


class JobProcessor:
"""Wraps the job info and provides common utils to download packages,
start drivers, etc.
Args:
job_info (JobInfo): The job info.
"""
_cmd_index_gen = itertools.count(1)

def __init__(self, job_info):
assert isinstance(job_info, JobInfo)
self._job_info = job_info

async def _download_package(self, http_session, url, filename):
unique_id = self._job_info.unique_id
cmd_index = next(self._cmd_index_gen)
logger.info("[%s] Start download[%s] %s to %s", unique_id, cmd_index,
url, filename)
async with http_session.get(url, ssl=False) as response:
with open(filename, "wb") as f:
while True:
chunk = await response.content.read(
job_consts.DOWNLOAD_BUFFER_SIZE)
if not chunk:
break
f.write(chunk)
logger.info("[%s] Finished download[%s] %s to %s", unique_id,
cmd_index, url, filename)

async def _unpack_package(self, filename, path):
code = f"import shutil; " \
f"shutil.unpack_archive({repr(filename)}, {repr(path)})"
unzip_cmd = [self._get_current_python(), "-c", code]
await self._check_output_cmd(unzip_cmd)

async def _check_output_cmd(self, cmd):
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE)
unique_id = self._job_info.unique_id
cmd_index = next(self._cmd_index_gen)
proc.cmd_index = cmd_index
logger.info("[%s] Run cmd[%s] %s", unique_id, cmd_index, repr(cmd))
stdout, stderr = await proc.communicate()
stdout = stdout.decode("utf-8")
logger.info("[%s] Output of cmd[%s]: %s", unique_id, cmd_index, stdout)
if proc.returncode != 0:
stderr = stderr.decode("utf-8")
logger.error("[%s] Error of cmd[%s]: %s", unique_id, cmd_index,
stderr)
raise subprocess.CalledProcessError(
proc.returncode, cmd, output=stdout, stderr=stderr)
return stdout

async def _start_driver(self, cmd, stdout, stderr, env):
unique_id = self._job_info.unique_id
job_package_dir = job_consts.JOB_UNPACK_DIR.format(
temp_dir=self._job_info.temp_dir, unique_id=unique_id)
cmd_str = subprocess.list2cmdline(cmd)
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=stdout,
stderr=stderr,
env={
**os.environ,
**env,
},
cwd=job_package_dir,
)
logger.info("[%s] Start driver cmd %s with pid %s", unique_id,
repr(cmd_str), proc.pid)
return proc

@staticmethod
def _get_current_python():
return sys.executable

@staticmethod
def _new_log_files(log_dir, filename):
if log_dir is None:
return None, None
stdout = open(
os.path.join(log_dir, filename + ".out"), "a", buffering=1)
stderr = open(
os.path.join(log_dir, filename + ".err"), "a", buffering=1)
return stdout, stderr

@abstractmethod
async def run(self):
pass


class DownloadPackage(JobProcessor):
""" Download the job package.
Args:
job_info (JobInfo): The job info.
http_session (aiohttp.ClientSession): The client session.
"""

def __init__(self, job_info, http_session):
super().__init__(job_info)
self._http_session = http_session

async def run(self):
temp_dir = self._job_info.temp_dir
unique_id = self._job_info.unique_id
filename = job_consts.DOWNLOAD_PACKAGE_FILE.format(
temp_dir=temp_dir, unique_id=unique_id)
unpack_dir = job_consts.JOB_UNPACK_DIR.format(
temp_dir=temp_dir, unique_id=unique_id)
url = self._job_info.runtime_env.working_dir
await self._download_package(self._http_session, url, filename)
await self._unpack_package(filename, unpack_dir)


class StartPythonDriver(JobProcessor):
""" Start the driver for Python job.
Args:
job_info (JobInfo): The job info.
redis_address (tuple): The (ip, port) of redis.
redis_password (str): The password of redis.
"""

_template = """import sys
sys.path.append({import_path})
import ray
from ray._private.utils import hex_to_binary
ray.init(ignore_reinit_error=True,
address={redis_address},
_redis_password={redis_password},
job_config=ray.job_config.JobConfig({job_config_args}),
)
import {driver_entry}
{driver_entry}.main({driver_args})
# If the driver exits normally, we invoke Ray.shutdown() again
# here, in case the user code forgot to invoke it.
ray.shutdown()
"""

def __init__(self, job_info, redis_address, redis_password):
super().__init__(job_info)
self._redis_address = redis_address
self._redis_password = redis_password

def _gen_driver_code(self):
temp_dir = self._job_info.temp_dir
unique_id = self._job_info.unique_id
job_package_dir = job_consts.JOB_UNPACK_DIR.format(
temp_dir=temp_dir, unique_id=unique_id)
driver_entry_file = job_consts.JOB_DRIVER_ENTRY_FILE.format(
temp_dir=temp_dir, unique_id=unique_id, uuid=uuid.uuid4())
ip, port = self._redis_address

# Per job config
job_config_items = {
"worker_env": self._job_info.env,
"code_search_path": [job_package_dir],
}

job_config_args = ", ".join(f"{key}={repr(value)}"
for key, value in job_config_items.items()
if value is not None)
driver_args = ", ".join([repr(x) for x in self._job_info.driver_args])
driver_code = self._template.format(
job_config_args=job_config_args,
import_path=repr(job_package_dir),
redis_address=repr(ip + ":" + str(port)),
redis_password=repr(self._redis_password),
driver_entry=self._job_info.driver_entry,
driver_args=driver_args)
with open(driver_entry_file, "w") as fp:
fp.write(driver_code)
return driver_entry_file

async def run(self):
python = self._get_current_python()
driver_file = self._gen_driver_code()
driver_cmd = [python, "-u", driver_file]
stdout_file, stderr_file = self._new_log_files(
self._job_info.log_dir, f"driver-{self._job_info.unique_id}")
return await self._start_driver(driver_cmd, stdout_file, stderr_file,
self._job_info.env)


class JobAgent(dashboard_utils.DashboardAgentModule,
job_agent_pb2_grpc.JobAgentServiceServicer):
""" The JobAgentService defined in job_agent.proto for initializing /
cleaning job environments.
"""

async def InitializeJobEnv(self, request, context):
# TODO(fyrestone): Handle duplicated InitializeJobEnv requests
# when initializing job environment.
# TODO(fyrestone): Support reinitialize job environment.

# TODO(fyrestone): Use job id instead of unique id.
unique_id = secrets.token_hex(6)

# Parse the job description from the request.
try:
job_description_data = json.loads(request.job_description)
job_info = JobInfo(
unique_id=unique_id,
temp_dir=self._dashboard_agent.temp_dir,
log_dir=self._dashboard_agent.log_dir,
**job_description_data)
except json.JSONDecodeError as ex:
error_message = str(ex)
error_message += f", job_payload:\n{request.job_description}"
logger.error("[%s] Initialize job environment failed, %s.",
unique_id, error_message)
return job_agent_pb2.InitializeJobEnvReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
error_message=error_message)
except Exception as ex:
logger.exception(ex)
return job_agent_pb2.InitializeJobEnvReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
error_message=traceback.format_exc())

async def _initialize_job_env():
os.makedirs(
job_consts.JOB_DIR.format(
temp_dir=job_info.temp_dir, unique_id=unique_id),
exist_ok=True)
# Download the job package.
await DownloadPackage(job_info,
self._dashboard_agent.http_session).run()
# Start the driver.
logger.info("[%s] Starting driver.", unique_id)
language = job_info.language
if language == job_consts.PYTHON:
driver = await StartPythonDriver(
job_info, self._dashboard_agent.redis_address,
self._dashboard_agent.redis_password).run()
else:
raise Exception(f"Unsupported language type: {language}")
job_info.driver = driver

initialize_task = create_task(_initialize_job_env())

try:
await initialize_task
except asyncio.CancelledError:
logger.error("[%s] Initialize job environment has been cancelled.",
unique_id)
return job_agent_pb2.InitializeJobEnvReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
error_message="InitializeJobEnv has been cancelled, "
"did you call CleanJobEnv?")
except Exception as ex:
logger.exception(ex)
return job_agent_pb2.InitializeJobEnvReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
error_message=traceback.format_exc())

driver_pid = 0
if job_info.driver:
driver_pid = job_info.driver.pid

logger.info(
"[%s] Job environment initialized, "
"the driver (pid=%s) started.", unique_id, driver_pid)
return job_agent_pb2.InitializeJobEnvReply(
status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
driver_pid=driver_pid)

async def run(self, server):
job_agent_pb2_grpc.add_JobAgentServiceServicer_to_server(self, server)
16 changes: 16 additions & 0 deletions dashboard/modules/job/job_consts.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,18 @@
import os
from ray.core.generated import common_pb2

# Job agent consts
# TODO(fyrestone): We should use job id instead of unique_id.
JOB_DIR = "{temp_dir}/job/{unique_id}/"
JOB_UNPACK_DIR = os.path.join(JOB_DIR, "package")
JOB_DRIVER_ENTRY_FILE = os.path.join(JOB_DIR, "driver-{uuid}.py")
# Downloader constants
DOWNLOAD_BUFFER_SIZE = 10 * 1024 * 1024 # 10MB
DOWNLOAD_PACKAGE_FILE = os.path.join(JOB_DIR, "package.zip")
# Redis key
JOB_CHANNEL = "JOB"
RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS = 2
# Languages
PYTHON = common_pb2.Language.Name(common_pb2.Language.PYTHON)
JAVA = common_pb2.Language.Name(common_pb2.Language.JAVA)
CPP = common_pb2.Language.Name(common_pb2.Language.CPP)
49 changes: 49 additions & 0 deletions dashboard/modules/job/job_description.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from enum import Enum
from typing import Dict, Any
from pydantic import BaseModel as PydanticBaseModel, Extra
from ray.core.generated import common_pb2


class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True
extra = Extra.forbid


class Language(str, Enum):
PYTHON = common_pb2.Language.Name(common_pb2.PYTHON)
JAVA = common_pb2.Language.Name(common_pb2.JAVA)
CPP = common_pb2.Language.Name(common_pb2.CPP)


class RuntimeEnv(BaseModel):
# The url to download the job package archive. The archive format is
# one of “zip”, “tar”, “gztar”, “bztar”, or “xztar”. Please refer to
# https://docs.python.org/3/library/shutil.html#shutil.unpack_archive
working_dir: str


class JobDescription(BaseModel):
# The job driver language, this field determines how to start the
# driver. The value is one of the names of enum Language defined in
# common.proto, e.g. PYTHON
language: Language
# The runtime_env (RuntimeEnvDict) for the job config.
runtime_env: RuntimeEnv
# The entry to start the driver.
# PYTHON:
# - The basename of driver filename without extension in the job
# package archive.
# JAVA:
# - The driver class full name in the job package archive.
driver_entry: str
# The driver arguments in list.
# PYTHON:
# - The arguments to pass to the main() function in driver entry.
# e.g. [1, False, 3.14, "abc"]
# JAVA:
# - The arguments to pass to the driver command line.
# e.g. ["-custom-arg", "abc"]
driver_args: list = []
# The environment vars to pass to job config, type of keys should be str.
env: Dict[str, Any] = {}
Loading

0 comments on commit 56c3094

Please sign in to comment.