Skip to content

Commit

Permalink
[Ray Client] Transfer dashboard_url over gRPC instead of ray.remote (r…
Browse files Browse the repository at this point in the history
…ay-project#30941)

The ray.remote call spawns a worker task on the head node even if the client doesn't do anything else, which starts unexpected worker processes.

Note: dashboard_url behavior is already tested by test_client_builder
  • Loading branch information
ckw017 authored Dec 8, 2022
1 parent 794cfd9 commit 5f5dd14
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python/ray/client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,9 @@ def connect(self) -> ClientContext:
ray_init_kwargs=self._remote_init_kwargs,
metadata=self._metadata,
)
get_dashboard_url = ray.remote(ray._private.worker.get_dashboard_url)
dashboard_url = ray.get(get_dashboard_url.options(num_cpus=0).remote())

dashboard_url = ray.util.client.ray._get_dashboard_url()

cxt = ClientContext(
dashboard_url=dashboard_url,
python_version=client_info_dict["python_version"],
Expand Down
38 changes: 38 additions & 0 deletions python/ray/tests/test_client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import ray
import ray.client_builder as client_builder
import ray.util.client.server.server as ray_client_server
from ray.experimental.state.api import list_workers
from ray._private.test_utils import (
run_string_as_driver,
run_string_as_driver_nonblocking,
wait_for_condition,
)
import time


@pytest.mark.parametrize(
Expand Down Expand Up @@ -419,6 +421,42 @@ def test_client_deprecation_warn():
subprocess.check_output("ray stop --force", shell=True)


@pytest.mark.parametrize(
    "call_ray_start",
    [
        "ray start --head --num-cpus=2 --min-worker-port=0 --max-worker-port=0 "
        "--port 0 --ray-client-server-port=50056"
    ],
    indirect=True,
)
def test_worker_processes(call_ray_start):
    """Check that connecting a client alone does not start worker processes.

    The test connects through the Ray Client port, verifies that only
    driver-type workers are listed, then runs one remote function and
    verifies that exactly two non-driver workers are listed afterwards.
    """

    def assert_non_driver_count_holds(expected, seconds=10):
        # Re-check once per second instead of once overall, so a worker that
        # shows up (or goes away) late still fails the test.
        for _ in range(seconds):
            all_workers = list_workers()
            spawned = [
                info for info in all_workers if info.get("worker_type") != "DRIVER"
            ]
            # Report the full worker list on failure to ease debugging.
            assert len(spawned) == expected, all_workers
            time.sleep(1)

    ray.init("ray://localhost:50056")

    # Nothing has been submitted yet, so no non-driver worker may exist.
    assert_non_driver_count_holds(0)

    @ray.remote(num_cpus=2)
    def f():
        return 42

    assert ray.get(f.remote()) == 42
    time.sleep(3)

    # After the remote call the test expects two non-driver workers.
    assert_non_driver_count_holds(2)


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
12 changes: 11 additions & 1 deletion python/ray/tests/test_client_reconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,17 @@ def ListNamedActors(
return self._call_inner_function(request, context, "ListNamedActors")

def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse:
return self._call_inner_function(request, context, "ClusterInfo")
# Cluster info is currently used for health checks and isn't retried, so
# don't inject errors.
# TODO(ckw): update ClusterInfo so that retries are only skipped for PING
try:
return self.stub.ClusterInfo(
request, metadata=context.invocation_metadata()
)
except grpc.RpcError as e:
context.set_code(e.code())
context.set_details(e.details())
raise

def Terminate(self, req, context=None):
return self._call_inner_function(req, context, "Terminate")
Expand Down
2 changes: 1 addition & 1 deletion python/ray/util/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2022-10-05"
CURRENT_PROTOCOL_VERSION = "2022-12-06"


class _ClientContext:
Expand Down
7 changes: 7 additions & 0 deletions python/ray/util/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,10 @@ def _register_callback(
self, ref: "ClientObjectRef", callback: Callable[["DataResponse"], None]
) -> None:
self.worker.register_callback(ref, callback)

def _get_dashboard_url(self) -> str:
    """Return the dashboard URL reported by the cluster-info request.

    Asks the worker for cluster info of type ``DASHBOARD_URL`` and reads the
    ``"dashboard_url"`` entry from the result; an empty string is returned
    when that entry is absent.
    """
    import ray.core.generated.ray_client_pb2 as ray_client_pb2

    info_type = ray_client_pb2.ClusterInfoType.DASHBOARD_URL
    cluster_info = self.worker.get_cluster_info(info_type)
    return cluster_info.get("dashboard_url", "")
2 changes: 2 additions & 0 deletions python/ray/util/client/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def _return_debug_cluster_info(self, request, context=None) -> str:
data = ray.timeline()
elif request.type == ray_client_pb2.ClusterInfoType.PING:
data = {}
elif request.type == ray_client_pb2.ClusterInfoType.DASHBOARD_URL:
data = {"dashboard_url": ray._private.worker.get_dashboard_url()}
else:
raise TypeError("Unsupported cluster info type")
return json.dumps(data)
Expand Down
1 change: 1 addition & 0 deletions src/ray/protobuf/ray_client.proto
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ message ClusterInfoType {
RUNTIME_CONTEXT = 4;
TIMELINE = 5;
PING = 6;
DASHBOARD_URL = 7;
}
}

Expand Down

0 comments on commit 5f5dd14

Please sign in to comment.