Skip to content

Commit

Permalink
[Job Submission][refactor 2/N] introduce job agent (ray-project#28203)
Browse files Browse the repository at this point in the history
  • Loading branch information
Catch-Bull authored Sep 3, 2022
1 parent a31be7c commit ce70b8b
Show file tree
Hide file tree
Showing 6 changed files with 661 additions and 135 deletions.
97 changes: 97 additions & 0 deletions dashboard/modules/job/job_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import aiohttp
from aiohttp.web import Request, Response
import dataclasses
import json
import logging
import traceback

import ray.dashboard.optional_utils as optional_utils
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.modules.job.common import (
JobSubmitRequest,
JobSubmitResponse,
)
from ray.dashboard.modules.job.job_manager import JobManager
from ray.dashboard.modules.job.utils import parse_and_validate_request, find_job_by_ids


# Route table used through `@routes.get(...)` / `@routes.post(...)` to decorate
# the HTTP handler methods defined in this module.
routes = optional_utils.ClassMethodRouteTable
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class JobAgent(dashboard_utils.DashboardAgentModule):
    """Dashboard agent module serving job endpoints under ``/api/job_agent``.

    Provides one endpoint to submit a job and one to look a job up by job id
    or submission id. The ``JobManager`` backing both endpoints is built on
    first use (see ``get_job_manager``).
    """

    def __init__(self, dashboard_agent):
        super().__init__(dashboard_agent)
        # Built on demand by get_job_manager().
        self._job_manager = None
        self._gcs_job_info_stub = None

    @routes.post("/api/job_agent/jobs/")
    @optional_utils.init_ray_and_catch_exceptions()
    async def submit_job(self, req: Request) -> Response:
        """Submit the job described by the request body.

        Returns:
            * the error ``Response`` produced by ``parse_and_validate_request``
              when the body could not be parsed,
            * 400 with the traceback as text on ``TypeError``/``ValueError``,
            * 500 with the traceback as text on any other exception,
            * otherwise 200 with a JSON-encoded ``JobSubmitResponse``.
        """
        parsed = await parse_and_validate_request(req, JobSubmitRequest)
        if isinstance(parsed, Response):
            # Request parsing failed, returned with Response object.
            return parsed

        submit_request = parsed
        # Fall back to ``job_id`` when no ``submission_id`` was supplied.
        requested_id = submit_request.submission_id or submit_request.job_id

        try:
            job_manager = self.get_job_manager()
            submission_id = await job_manager.submit_job(
                entrypoint=submit_request.entrypoint,
                submission_id=requested_id,
                runtime_env=submit_request.runtime_env,
                metadata=submit_request.metadata,
                _driver_on_current_node=False,
            )
            resp = JobSubmitResponse(
                job_id=submission_id,
                submission_id=submission_id,
            )
        except (TypeError, ValueError):
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPBadRequest.status_code,
            )
        except Exception:
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPInternalServerError.status_code,
            )
        else:
            body = json.dumps(dataclasses.asdict(resp))
            return Response(
                text=body,
                content_type="application/json",
                status=aiohttp.web.HTTPOk.status_code,
            )

    @routes.get("/api/job_agent/jobs/{job_or_submission_id}")
    @optional_utils.init_ray_and_catch_exceptions()
    async def get_job_info(self, req: Request) -> Response:
        """Return the JSON details of the job named in the URL.

        The path component may be either a job id or a submission id; the
        lookup is delegated to ``find_job_by_ids``. Responds with 404 when no
        job is found.
        """
        job_or_submission_id = req.match_info["job_or_submission_id"]
        gcs_client = self._dashboard_agent.gcs_aio_client
        job_manager = self.get_job_manager()

        job = await find_job_by_ids(gcs_client, job_manager, job_or_submission_id)
        if job:
            return Response(
                text=json.dumps(job.dict()),
                content_type="application/json",
            )

        return Response(
            text=f"Job {job_or_submission_id} does not exist",
            status=aiohttp.web.HTTPNotFound.status_code,
        )

    def get_job_manager(self):
        """Return the shared ``JobManager``, creating it on the first call."""
        manager = self._job_manager
        if not manager:
            manager = JobManager(self._dashboard_agent.gcs_aio_client)
            self._job_manager = manager
        return manager

    async def run(self, server):
        """No-op: this module does nothing when run."""

    @staticmethod
    def is_minimal_module():
        """Report that this module is not a minimal module."""
        return False
201 changes: 72 additions & 129 deletions dashboard/modules/job/job_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,109 +2,102 @@
import json
import logging
import traceback
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import aiohttp.web
from aiohttp.web import Request, Response
from aiohttp.client import ClientResponse

import ray
from ray._private import ray_constants
import ray.dashboard.optional_utils as optional_utils
import ray.dashboard.utils as dashboard_utils
from ray._private.runtime_env.packaging import (
package_exists,
pin_runtime_env_uri,
upload_package_to_gcs,
)
from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc
from ray.dashboard.modules.job.common import (
http_uri_components_to_uri,
JobStatus,
JobSubmitRequest,
JobSubmitResponse,
JobStopResponse,
JobLogsResponse,
validate_request_type,
JOB_ID_METADATA_KEY,
)
from ray.dashboard.modules.job.pydantic_models import (
DriverInfo,
JobDetails,
JobType,
)
from ray.dashboard.modules.job.utils import (
parse_and_validate_request,
get_driver_jobs,
find_job_by_ids,
)
from ray.dashboard.modules.version import (
CURRENT_VERSION,
VersionResponse,
)
from ray.dashboard.modules.job.job_manager import JobManager
from ray.runtime_env import RuntimeEnv

# Module-level logger. Its level is pinned to INFO, so `logger.debug(...)`
# calls made through it are filtered out unless the level is lowered.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Route table used through `@routes.get(...)` / `@routes.post(...)` to decorate
# the HTTP handler methods defined in this module.
routes = optional_utils.ClassMethodRouteTable


class JobAgentSubmissionClient:
    """HTTP client for the job agent REST API of one specific node.

    Every request is sent to ``dashboard_agent_address`` through a single
    ``aiohttp.ClientSession`` owned by this object; call ``close()`` when the
    client is no longer needed.
    """

    def __init__(
        self,
        dashboard_agent_address: str,
    ):
        # Base address; endpoint paths are appended to it verbatim.
        self._address = dashboard_agent_address
        self._session = aiohttp.ClientSession()

    async def _raise_error(self, r: ClientResponse):
        """Raise ``RuntimeError`` carrying the status and body text of ``r``."""
        status = r.status
        error_text = await r.text()
        raise RuntimeError(f"Request failed with status code {status}: {error_text}.")

    async def submit_job_internal(self, req: JobSubmitRequest) -> JobSubmitResponse:
        """POST ``req`` to the agent's job endpoint.

        Returns the decoded ``JobSubmitResponse`` on HTTP 200; any other
        status is turned into a ``RuntimeError`` by ``_raise_error``.
        """
        logger.debug(f"Submitting job with submission_id={req.submission_id}.")

        url = self._address + "/api/job_agent/jobs/"
        payload = dataclasses.asdict(req)
        async with self._session.post(url, json=payload) as resp:
            if resp.status != 200:
                return await self._raise_error(resp)
            fields = await resp.json()
            return JobSubmitResponse(**fields)

    async def get_job_info(self, job_id: str) -> JobDetails:
        """GET the details of ``job_id`` from the agent.

        Returns the decoded ``JobDetails`` on HTTP 200; any other status is
        turned into a ``RuntimeError`` by ``_raise_error``.
        """
        url = self._address + f"/api/job_agent/jobs/{job_id}"
        async with self._session.get(url) as resp:
            if resp.status != 200:
                return await self._raise_error(resp)
            fields = await resp.json()
            return JobDetails(**fields)

    async def close(self, ignore_error=True):
        """Close the underlying session.

        Args:
            ignore_error: When true (the default) an exception raised while
                closing is swallowed; otherwise it is re-raised.
        """
        try:
            await self._session.close()
        except Exception:
            if ignore_error:
                return
            raise


class JobHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
    """Store the dashboard head and start with no job manager or GCS stub."""
    super().__init__(dashboard_head)
    self._dashboard_head = dashboard_head
    # Neither helper exists yet at construction time.
    self._gcs_job_info_stub = None
    self._job_manager = None

async def _parse_and_validate_request(
    self, req: Request, request_type: type
) -> Any:
    """Parse the JSON body of ``req`` and cast it to ``request_type``.

    Args:
        req: Incoming HTTP request whose body is read as JSON.
        request_type: Class handed to ``validate_request_type`` together
            with the decoded body. (Annotated as ``type``: the previous
            ``dataclass`` annotation named the decorator function, which is
            not a type.)

    Returns:
        Whatever ``validate_request_type`` returns on success. If reading or
        validating the body raises, a ``Response`` with status 400 whose text
        is the formatted traceback is returned instead, so callers need to
        check ``isinstance(result, Response)``.
    """
    try:
        return validate_request_type(await req.json(), request_type)
    except Exception as e:
        # Lazy %-style arguments: the message is formatted only if the record
        # is actually emitted. The resulting text is unchanged.
        logger.info("Got invalid request type: %s", e)
        return Response(
            text=traceback.format_exc(),
            status=aiohttp.web.HTTPBadRequest.status_code,
        )

async def find_job_by_ids(self, job_or_submission_id: str) -> Optional[JobDetails]:
    """Look up a job by either its job id or its submission id.

    Resolution order:
      1. a driver job whose job id equals the argument;
      2. a submission job whose driver has that id;
      3. a submission job whose submission id equals the argument.

    Returns ``None`` when the job manager has no info for the resolved
    submission id.
    """
    driver_jobs, submission_job_drivers = await self._get_driver_jobs()

    # 1. Direct hit among driver jobs (keyed by job id).
    driver_job = driver_jobs.get(job_or_submission_id)
    if driver_job:
        return driver_job

    # 2. Search for a submission whose driver carries the given id.
    matched_submission_id = None
    for candidate_id, candidate_driver in submission_job_drivers.items():
        if candidate_driver.id == job_or_submission_id:
            matched_submission_id = candidate_id
            break

    # 3. No such driver: treat the argument itself as a submission id.
    submission_id = matched_submission_id or job_or_submission_id

    job_info = await self._job_manager.get_job_info(submission_id)
    if not job_info:
        return None

    driver = submission_job_drivers.get(submission_id)
    return JobDetails(
        **dataclasses.asdict(job_info),
        submission_id=submission_id,
        job_id=driver.id if driver else None,
        driver_info=driver,
        type=JobType.SUBMISSION,
    )

@routes.get("/api/version")
async def get_version(self, req: Request) -> Response:
Expand Down Expand Up @@ -167,7 +160,7 @@ async def upload_package(self, req: Request):
@routes.post("/api/jobs/")
@optional_utils.init_ray_and_catch_exceptions()
async def submit_job(self, req: Request) -> Response:
result = await self._parse_and_validate_request(req, JobSubmitRequest)
result = await parse_and_validate_request(req, JobSubmitRequest)
# Request parsing failed, returned with Response object.
if isinstance(result, Response):
return result
Expand Down Expand Up @@ -206,7 +199,9 @@ async def submit_job(self, req: Request) -> Response:
@optional_utils.init_ray_and_catch_exceptions()
async def stop_job(self, req: Request) -> Response:
job_or_submission_id = req.match_info["job_or_submission_id"]
job = await self.find_job_by_ids(job_or_submission_id)
job = await find_job_by_ids(
self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
)
if not job:
return Response(
text=f"Job {job_or_submission_id} does not exist",
Expand Down Expand Up @@ -235,7 +230,9 @@ async def stop_job(self, req: Request) -> Response:
@optional_utils.init_ray_and_catch_exceptions()
async def get_job_info(self, req: Request) -> Response:
job_or_submission_id = req.match_info["job_or_submission_id"]
job = await self.find_job_by_ids(job_or_submission_id)
job = await find_job_by_ids(
self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
)
if not job:
return Response(
text=f"Job {job_or_submission_id} does not exist",
Expand All @@ -250,7 +247,9 @@ async def get_job_info(self, req: Request) -> Response:
@routes.get("/api/jobs/")
@optional_utils.init_ray_and_catch_exceptions()
async def list_jobs(self, req: Request) -> Response:
driver_jobs, submission_job_drivers = await self._get_driver_jobs()
driver_jobs, submission_job_drivers = await get_driver_jobs(
self._dashboard_head.gcs_aio_client
)

submission_jobs = await self._job_manager.list_jobs()
submission_jobs = [
Expand All @@ -275,67 +274,13 @@ async def list_jobs(self, req: Request) -> Response:
content_type="application/json",
)

async def _get_driver_jobs(
    self,
) -> Tuple[Dict[str, JobDetails], Dict[str, DriverInfo]]:
    """Fetch all jobs from GCS and split them into two driver-related maps.

    Returns:
        A ``(jobs, submission_job_drivers)`` tuple:

        * ``jobs`` maps job id -> ``JobDetails`` for every driver job, i.e.
          every entry whose metadata carries no submission id.
        * ``submission_job_drivers`` maps submission id -> ``DriverInfo``
          for entries that belong to a submission job. A later entry with
          the same submission id overwrites an earlier one, so only the
          last driver of each submission job is kept.

    Entries whose namespace starts with the Ray-internal prefix are skipped.
    """
    request = gcs_service_pb2.GetAllJobInfoRequest()
    reply = await self._gcs_job_info_stub.GetAllJobInfo(request, timeout=5)

    jobs = {}
    submission_job_drivers = {}
    internal_prefix = ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX

    for entry in reply.job_info_list:
        # Skip jobs in any _ray_internal_ namespace
        if entry.config.ray_namespace.startswith(internal_prefix):
            continue

        job_id = entry.job_id.hex()
        metadata = dict(entry.config.metadata)
        job_submission_id = metadata.get(JOB_ID_METADATA_KEY)

        # The driver record is built the same way for both kinds of entry.
        driver = DriverInfo(
            id=job_id,
            node_ip_address=entry.driver_ip_address,
            pid=entry.driver_pid,
        )

        if job_submission_id:
            submission_job_drivers[job_submission_id] = driver
            continue

        status = JobStatus.SUCCEEDED if entry.is_dead else JobStatus.RUNNING
        serialized_env = entry.config.runtime_env_info.serialized_runtime_env
        jobs[job_id] = JobDetails(
            job_id=job_id,
            type=JobType.DRIVER,
            status=status,
            entrypoint="",
            start_time=entry.start_time,
            end_time=entry.end_time,
            metadata=metadata,
            runtime_env=RuntimeEnv.deserialize(serialized_env).to_dict(),
            driver_info=driver,
        )

    return jobs, submission_job_drivers

@routes.get("/api/jobs/{job_or_submission_id}/logs")
@optional_utils.init_ray_and_catch_exceptions()
async def get_job_logs(self, req: Request) -> Response:
job_or_submission_id = req.match_info["job_or_submission_id"]
job = await self.find_job_by_ids(job_or_submission_id)
job = await find_job_by_ids(
self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
)
if not job:
return Response(
text=f"Job {job_or_submission_id} does not exist",
Expand All @@ -357,7 +302,9 @@ async def get_job_logs(self, req: Request) -> Response:
@optional_utils.init_ray_and_catch_exceptions()
async def tail_job_logs(self, req: Request) -> Response:
job_or_submission_id = req.match_info["job_or_submission_id"]
job = await self.find_job_by_ids(job_or_submission_id)
job = await find_job_by_ids(
self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
)
if not job:
return Response(
text=f"Job {job_or_submission_id} does not exist",
Expand All @@ -380,10 +327,6 @@ async def run(self, server):
if not self._job_manager:
self._job_manager = JobManager(self._dashboard_head.gcs_aio_client)

self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel
)

@staticmethod
def is_minimal_module():
return False
Loading

0 comments on commit ce70b8b

Please sign in to comment.