Skip to content

Commit

Permalink
[Core] Make "import ray" work without grpcio (ray-project#35737)
Browse files Browse the repository at this point in the history
Why are these changes needed?

Import cleanups to push grpc imports out of the critical path of Ray

Related issue number

Next step in ray-project#35472

Signed-off-by: Philipp Moritz <[email protected]>
Co-authored-by: SangBin Cho <[email protected]>
  • Loading branch information
pcmoritz and rkooo567 authored May 25, 2023
1 parent 3c03523 commit ed661cd
Show file tree
Hide file tree
Showing 17 changed files with 90 additions and 59 deletions.
1 change: 1 addition & 0 deletions dashboard/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from queue import Queue

import ray._private.services
import ray._private.tls_utils
import ray._private.utils
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
Expand Down
4 changes: 1 addition & 3 deletions python/ray/_private/import_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import traceback
from collections import defaultdict

import grpc

import ray
import ray._private.profiling as profiling
from ray import JobID
Expand Down Expand Up @@ -35,7 +33,7 @@ def __init__(self, worker, mode, threads_stopped):
self.gcs_client = worker.gcs_client
self.subscriber = worker.gcs_function_key_subscriber
self.subscriber.subscribe()
self.exception_type = grpc.RpcError
self.exception_type = ray.exceptions.RpcError
self.threads_stopped = threads_stopped
self.imported_collision_identifiers = defaultdict(int)
self.t = None
Expand Down
3 changes: 1 addition & 2 deletions python/ray/_private/runtime_env/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR,
RAY_RUNTIME_ENV_IGNORE_GITIGNORE,
)
from ray._private.gcs_utils import GcsAioClient
from ray._private.thirdparty.pathspec import PathSpec
from ray.experimental.internal_kv import (
_internal_kv_exists,
Expand Down Expand Up @@ -594,7 +593,7 @@ def get_local_dir_from_uri(uri: str, base_directory: str) -> Path:
async def download_and_unpack_package(
pkg_uri: str,
base_directory: str,
gcs_aio_client: Optional[GcsAioClient] = None,
gcs_aio_client: Optional["GcsAioClient"] = None, # noqa: F821
logger: Optional[logging.Logger] = default_logger,
) -> str:
"""Download the package corresponding to this URI and unpack it if zipped.
Expand Down
5 changes: 3 additions & 2 deletions python/ray/_private/runtime_env/py_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from types import ModuleType
from typing import Any, Dict, List, Optional

from ray._private.gcs_utils import GcsAioClient
from ray._private.runtime_env.conda_utils import exec_cmd_stream_to_logger
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.packaging import (
Expand Down Expand Up @@ -141,7 +140,9 @@ class PyModulesPlugin(RuntimeEnvPlugin):

name = "py_modules"

def __init__(self, resources_dir: str, gcs_aio_client: GcsAioClient):
def __init__(
self, resources_dir: str, gcs_aio_client: "GcsAioClient" # noqa: F821
):
self._resources_dir = os.path.join(resources_dir, "py_modules_files")
self._gcs_aio_client = gcs_aio_client
try_to_create_directory(self._resources_dir)
Expand Down
5 changes: 3 additions & 2 deletions python/ray/_private/runtime_env/working_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

from ray._private.gcs_utils import GcsAioClient
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.packaging import (
Protocol,
Expand Down Expand Up @@ -120,7 +119,9 @@ class WorkingDirPlugin(RuntimeEnvPlugin):

name = "working_dir"

def __init__(self, resources_dir: str, gcs_aio_client: GcsAioClient):
def __init__(
self, resources_dir: str, gcs_aio_client: "GcsAioClient" # noqa: F821
):
self._resources_dir = os.path.join(resources_dir, "working_dir_files")
self._gcs_aio_client = gcs_aio_client
try_to_create_directory(self._resources_dir)
Expand Down
3 changes: 1 addition & 2 deletions python/ray/_private/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import ray._private.utils
import ray.cloudpickle as pickle
from ray._private import ray_constants
from ray._private.gcs_utils import ErrorType
from ray._raylet import (
MessagePackSerializedObject,
MessagePackSerializer,
Expand All @@ -18,7 +17,7 @@
split_buffer,
unpack_pickle5_buffers,
)
from ray.core.generated.common_pb2 import RayErrorInfo
from ray.core.generated.common_pb2 import ErrorType, RayErrorInfo
from ray.exceptions import (
ActorPlacementGroupRemoved,
ActorUnschedulableError,
Expand Down
36 changes: 18 additions & 18 deletions python/ray/_private/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from google.protobuf.json_format import MessageToDict

import ray
import ray._private.gcs_utils as gcs_utils
from ray._private.client_mode_hook import client_mode_hook
from ray._private.resource_spec import NODE_ID_PREFIX
from ray._private.utils import binary_to_hex, decode, hex_to_binary
from ray._raylet import GlobalStateAccessor
from ray.core.generated import common_pb2
from ray.core.generated import gcs_pb2
from ray.util.annotations import DeveloperAPI

Expand Down Expand Up @@ -92,13 +92,13 @@ def actor_table(self, actor_id):
if actor_info is None:
return {}
else:
actor_table_data = gcs_utils.ActorTableData.FromString(actor_info)
actor_table_data = gcs_pb2.ActorTableData.FromString(actor_info)
return self._gen_actor_info(actor_table_data)
else:
actor_table = self.global_state_accessor.get_actor_table()
results = {}
for i in range(len(actor_table)):
actor_table_data = gcs_utils.ActorTableData.FromString(actor_table[i])
actor_table_data = gcs_pb2.ActorTableData.FromString(actor_table[i])
results[
binary_to_hex(actor_table_data.actor_id)
] = self._gen_actor_info(actor_table_data)
Expand Down Expand Up @@ -167,7 +167,7 @@ def job_table(self):

results = []
for i in range(len(job_table)):
entry = gcs_utils.JobTableData.FromString(job_table[i])
entry = gcs_pb2.JobTableData.FromString(job_table[i])
job_info = {}
job_info["JobID"] = entry.job_id.hex()
job_info["DriverIPAddress"] = entry.driver_ip_address
Expand Down Expand Up @@ -215,7 +215,7 @@ def profile_events(self):
result = defaultdict(list)
task_events = self.global_state_accessor.get_task_events()
for i in range(len(task_events)):
event = gcs_utils.TaskEvents.FromString(task_events[i])
event = common_pb2.TaskEvents.FromString(task_events[i])
profile = event.profile_events
if not profile:
continue
Expand Down Expand Up @@ -252,7 +252,7 @@ def get_placement_group_by_name(self, placement_group_name, ray_namespace):
if placement_group_info is None:
return None
else:
placement_group_table_data = gcs_utils.PlacementGroupTableData.FromString(
placement_group_table_data = gcs_pb2.PlacementGroupTableData.FromString(
placement_group_info
)
return self._gen_placement_group_info(placement_group_table_data)
Expand All @@ -270,7 +270,7 @@ def placement_group_table(self, placement_group_id=None):
if placement_group_info is None:
return {}
else:
placement_group_info = gcs_utils.PlacementGroupTableData.FromString(
placement_group_info = gcs_pb2.PlacementGroupTableData.FromString(
placement_group_info
)
return self._gen_placement_group_info(placement_group_info)
Expand All @@ -280,8 +280,8 @@ def placement_group_table(self, placement_group_id=None):
)
results = {}
for placement_group_info in placement_group_table:
placement_group_table_data = (
gcs_utils.PlacementGroupTableData.FromString(placement_group_info)
placement_group_table_data = gcs_pb2.PlacementGroupTableData.FromString(
placement_group_info
)
placement_group_id = binary_to_hex(
placement_group_table_data.placement_group_id
Expand All @@ -297,11 +297,11 @@ def _gen_placement_group_info(self, placement_group_info):
from ray.core.generated.common_pb2 import PlacementStrategy

def get_state(state):
if state == gcs_utils.PlacementGroupTableData.PENDING:
if state == gcs_pb2.PlacementGroupTableData.PENDING:
return "PENDING"
elif state == gcs_utils.PlacementGroupTableData.CREATED:
elif state == gcs_pb2.PlacementGroupTableData.CREATED:
return "CREATED"
elif state == gcs_utils.PlacementGroupTableData.RESCHEDULING:
elif state == gcs_pb2.PlacementGroupTableData.RESCHEDULING:
return "RESCHEDULING"
else:
return "REMOVED"
Expand Down Expand Up @@ -602,10 +602,10 @@ def workers(self):
worker_table = self.global_state_accessor.get_worker_table()
workers_data = {}
for i in range(len(worker_table)):
worker_table_data = gcs_utils.WorkerTableData.FromString(worker_table[i])
worker_table_data = gcs_pb2.WorkerTableData.FromString(worker_table[i])
if (
worker_table_data.is_alive
and worker_table_data.worker_type == gcs_utils.WORKER
and worker_table_data.worker_type == common_pb2.WORKER
):
worker_id = binary_to_hex(worker_table_data.worker_address.worker_id)
worker_info = worker_table_data.worker_info
Expand All @@ -629,14 +629,14 @@ def add_worker(self, worker_id, worker_type, worker_info):
Args:
worker_id: ID of this worker. Type is bytes.
worker_type: Type of this worker. Value is gcs_utils.DRIVER or
gcs_utils.WORKER.
worker_type: Type of this worker. Value is common_pb2.DRIVER or
common_pb2.WORKER.
worker_info: Info of this worker. Type is dict{str: str}.
Returns:
Is operation success
"""
worker_data = gcs_utils.WorkerTableData()
worker_data = gcs_pb2.WorkerTableData()
worker_data.is_alive = True
worker_data.worker_address.worker_id = worker_id
worker_data.worker_type = worker_type
Expand Down Expand Up @@ -680,7 +680,7 @@ def _available_resources_per_node(self):
self.global_state_accessor.get_all_available_resources()
)
for available_resource in all_available_resources:
message = gcs_utils.AvailableResources.FromString(available_resource)
message = gcs_pb2.AvailableResources.FromString(available_resource)
# Calculate available resources for this node.
dynamic_resources = {}
for resource_id, capacity in message.resources_available.items():
Expand Down
3 changes: 2 additions & 1 deletion python/ray/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

import ray
import ray._private.ray_constants as ray_constants
from ray._private.tls_utils import load_certs_from_env
from ray.core.generated.runtime_env_common_pb2 import (
RuntimeEnvInfo as ProtoRuntimeEnvInfo,
)
Expand Down Expand Up @@ -1302,6 +1301,8 @@ def init_grpc_channel(
except ImportError:
from grpc.experimental import aio as aiogrpc

from ray._private.tls_utils import load_certs_from_env

grpc_module = aiogrpc if asynchronous else grpc

options = options or []
Expand Down
30 changes: 17 additions & 13 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from typing_extensions import Literal, Protocol

import ray
import ray._private.gcs_utils as gcs_utils
import ray._private.import_thread as import_thread
import ray._private.node
import ray._private.parameter
Expand All @@ -64,11 +63,7 @@
from ray._private import ray_option_utils
from ray._private.client_mode_hook import client_mode_hook
from ray._private.function_manager import FunctionActorManager, make_function_table_key
from ray._private.gcs_pubsub import (
GcsErrorSubscriber,
GcsFunctionKeySubscriber,
GcsLogSubscriber,
)

from ray._private.inspect_util import is_cython
from ray._private.ray_logging import (
global_worker_stdstream_dispatcher,
Expand Down Expand Up @@ -1742,11 +1737,13 @@ def sigterm_handler(signum, frame):


def custom_excepthook(type, value, tb):
import ray.core.generated.common_pb2 as common_pb2

# If this is a driver, push the exception to GCS worker table.
if global_worker.mode == SCRIPT_MODE and hasattr(global_worker, "worker_id"):
error_message = "".join(traceback.format_tb(tb))
worker_id = global_worker.worker_id
worker_type = gcs_utils.DRIVER
worker_type = common_pb2.DRIVER
worker_info = {"exception": error_message}

ray._private.state.state._check_connected()
Expand Down Expand Up @@ -2236,6 +2233,12 @@ def connect(
worker_launch_time_ms,
worker_launched_time_ms,
)
# The following will be fixed with https://github.com/ray-project/ray/pull/35094
from ray._private.gcs_pubsub import (
GcsErrorSubscriber,
GcsFunctionKeySubscriber,
GcsLogSubscriber,
)

# Notify raylet that the core worker is ready.
worker.core_worker.notify_raylet()
Expand Down Expand Up @@ -2586,14 +2589,15 @@ def put(
elif isinstance(_owner, ray.actor.ActorHandle):
# Ensure `ray._private.state.state.global_state_accessor` is not None
ray._private.state.state._check_connected()
owner_address = gcs_utils.ActorTableData.FromString(
ray._private.state.state.global_state_accessor.get_actor_info(
_owner._actor_id
serialize_owner_address = (
ray._raylet._get_actor_serialized_owner_address_or_none(
ray._private.state.state.global_state_accessor.get_actor_info(
_owner._actor_id
)
)
).address
if len(owner_address.worker_id) == 0:
)
if not serialize_owner_address:
raise RuntimeError(f"{_owner} is not alive, it's worker_id is empty!")
serialize_owner_address = owner_address.SerializeToString()
else:
raise TypeError(f"Expect an `ray.actor.ActorHandle`, but got: {type(_owner)}")

Expand Down
17 changes: 15 additions & 2 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ from ray.includes.common cimport (
CWorkerExitType,
CRayObject,
CRayStatus,
CActorTableData,
CErrorTableData,
CGcsClientOptions,
CGcsNodeInfo,
Expand Down Expand Up @@ -157,7 +158,7 @@ from ray._private.async_compat import (
is_async_func
)
from ray._private.client_mode_hook import disable_client_hook
import ray._private.gcs_utils as gcs_utils
import ray.core.generated.common_pb2 as common_pb2
import ray._private.memory_monitor as memory_monitor
import ray._private.profiling as profiling
from ray._private.utils import decode, DeferSigint
Expand Down Expand Up @@ -507,6 +508,18 @@ cdef c_vector[CObjectID] ObjectRefsToVector(object_refs):
return result


def _get_actor_serialized_owner_address_or_none(actor_table_data: bytes):
    """Return the serialized address stored in a serialized ActorTableData.

    Args:
        actor_table_data: Serialized ``ActorTableData`` protobuf bytes.

    Returns:
        The address field of the parsed message, serialized to a string,
        or ``None`` when that address has an empty ``worker_id``.
    """
    cdef CActorTableData actor_data

    # Parse on the C++ side rather than through the Python protobuf
    # bindings.
    actor_data.ParseFromString(actor_table_data)

    # An empty worker_id means there is no usable address to hand back.
    if actor_data.address().worker_id() == b"":
        return None
    return actor_data.address().SerializeAsString()


def compute_task_id(ObjectRef object_ref):
return TaskID(object_ref.native().TaskId().Binary())

Expand Down Expand Up @@ -3724,7 +3737,7 @@ cdef class CoreWorker:
# the job config will not change after a job is submitted.
if self.job_config is None:
c_job_config = CCoreWorkerProcess.GetCoreWorker().GetJobConfig()
self.job_config = gcs_utils.JobConfig()
self.job_config = common_pb2.JobConfig()
self.job_config.ParseFromString(c_job_config.SerializeAsString())
return self.job_config

Expand Down
3 changes: 3 additions & 0 deletions python/ray/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,9 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
void set_actor_name(const c_string &actor_name)
void set_task_name(const c_string &task_name)

# Cython view of the C++ protobuf message ray::rpc::ActorTableData.
# Only the members declared here are visible to Cython code: the
# `address()` accessor and `ParseFromString()` for deserializing from
# a serialized string.
cdef cppclass CActorTableData "ray::rpc::ActorTableData":
CAddress address() const
void ParseFromString(const c_string &serialized)

cdef extern from "ray/common/task/task_spec.h" nogil:
cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup":
Expand Down
2 changes: 0 additions & 2 deletions python/ray/includes/object_ref.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import threading
from typing import Callable, Any, Union

import ray
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import cython
import ray.util.client as client

logger = logging.getLogger(__name__)

Expand Down
Loading

0 comments on commit ed661cd

Please sign in to comment.