Skip to content

Commit

Permalink
[core] move cluster_id to GcsClientOptions, and sends RPC to get it i…
Browse files Browse the repository at this point in the history
…f absent. (ray-project#46358)

Previously, `cluster_id` was passed as an argument to
(Python)GcsClient::Connect. This argument never changes throughout the
lifetime of a GcsClient, so it belongs in GcsClientOptions.

Renamed Python binding `GcsClientOptions.from_gcs_address` to
`GcsClientOptions.create`.

Also, today's C++ GcsClient never requests a cluster_id, so if it is
absent, it is never populated in regular GCS requests. This commit adds
an RPC at Connect() time to fetch it if absent. To support that, it also
adds a `timeout` argument in case GCS is down.

One caveat: if the GcsClient lives inside GCS itself, the RPC blocks the
GCS main thread and deadlocks. So within GCS, cluster_id *must* be
populated beforehand.
  • Loading branch information
rynewang authored Jul 9, 2024
1 parent c630ccc commit 6a7521c
Show file tree
Hide file tree
Showing 33 changed files with 294 additions and 101 deletions.
6 changes: 2 additions & 4 deletions cpp/src/ray/runtime/native_ray_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@ NativeRayRuntime::NativeRayRuntime() {
if (bootstrap_address.empty()) {
bootstrap_address = GetNodeIpAddress();
}
bootstrap_address =
bootstrap_address + ":" + std::to_string(ConfigInternal::Instance().bootstrap_port);
global_state_accessor_ =
ProcessHelper::GetInstance().CreateGlobalStateAccessor(bootstrap_address);
global_state_accessor_ = ProcessHelper::GetInstance().CreateGlobalStateAccessor(
bootstrap_address, ConfigInternal::Instance().bootstrap_port);
}

const WorkerContext &NativeRayRuntime::GetWorkerContext() {
Expand Down
17 changes: 13 additions & 4 deletions cpp/src/ray/util/process_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ void ProcessHelper::StopRayNode() {
}

std::unique_ptr<ray::gcs::GlobalStateAccessor> ProcessHelper::CreateGlobalStateAccessor(
const std::string &gcs_address) {
ray::gcs::GcsClientOptions client_options(gcs_address);
const std::string &gcs_ip, int gcs_port) {
ray::gcs::GcsClientOptions client_options(gcs_ip,
gcs_port,
ray::ClusterID::Nil(),
/*allow_cluster_id_nil=*/true,
/*fetch_cluster_id_if_nil=*/false);
auto global_state_accessor =
std::make_unique<ray::gcs::GlobalStateAccessor>(client_options);
RAY_CHECK(global_state_accessor->Connect()) << "Failed to connect to GCS.";
Expand Down Expand Up @@ -93,7 +97,7 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback)
}

std::unique_ptr<ray::gcs::GlobalStateAccessor> global_state_accessor =
CreateGlobalStateAccessor(bootstrap_address);
CreateGlobalStateAccessor(bootstrap_ip, bootstrap_port);
if (ConfigInternal::Instance().worker_type == WorkerType::DRIVER) {
std::string node_to_connect;
auto status =
Expand All @@ -115,7 +119,12 @@ void ProcessHelper::RayStart(CoreWorkerOptions::TaskExecutionCallback callback)
ConfigInternal::Instance().UpdateSessionDir(session_dir);
}

gcs::GcsClientOptions gcs_options = gcs::GcsClientOptions(bootstrap_address);
gcs::GcsClientOptions gcs_options =
gcs::GcsClientOptions(bootstrap_ip,
bootstrap_port,
ClusterID::Nil(),
/*allow_cluster_id_nil=*/true,
/*fetch_cluster_id_if_nil=*/false);

CoreWorkerOptions options;
options.worker_type = ConfigInternal::Instance().worker_type;
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/ray/util/process_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ProcessHelper {
}

std::unique_ptr<ray::gcs::GlobalStateAccessor> CreateGlobalStateAccessor(
const std::string &gcs_address);
const std::string &gcs_ip, int gcs_port);

ProcessHelper(ProcessHelper const &) = delete;
void operator=(ProcessHelper const &) = delete;
Expand Down
4 changes: 3 additions & 1 deletion python/ray/_private/internal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def get_state_from_address(address=None):
address = services.canonicalize_bootstrap_address_or_die(address)

state = GlobalState()
options = GcsClientOptions.from_gcs_address(address)
options = GcsClientOptions.create(
address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=False
)
state._initialize_global_state(options)
return state

Expand Down
4 changes: 3 additions & 1 deletion python/ray/_private/metrics_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,9 @@ class PrometheusServiceDiscoveryWriter(threading.Thread):
"""

def __init__(self, gcs_address, temp_dir):
gcs_client_options = ray._raylet.GcsClientOptions.from_gcs_address(gcs_address)
gcs_client_options = ray._raylet.GcsClientOptions.create(
gcs_address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=False
)
self.gcs_address = gcs_address

ray._private.state.state._initialize_global_state(gcs_client_options)
Expand Down
7 changes: 5 additions & 2 deletions python/ray/_private/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,8 +1376,11 @@ def start_ray_processes(self):

if not self.head:
# Get the system config from GCS first if this is a non-head node.
gcs_options = ray._raylet.GcsClientOptions.from_gcs_address(
self.gcs_address
gcs_options = ray._raylet.GcsClientOptions.create(
self.gcs_address,
self.cluster_id.hex(),
allow_cluster_id_nil=False,
fetch_cluster_id_if_nil=False,
)
global_state = ray._private.state.GlobalState()
global_state._initialize_global_state(gcs_options)
Expand Down
11 changes: 9 additions & 2 deletions python/ray/_private/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,12 @@ def _build_python_executable_command_memory_profileable(


def _get_gcs_client_options(gcs_server_address):
return GcsClientOptions.from_gcs_address(gcs_server_address)
return GcsClientOptions.create(
gcs_server_address,
None,
allow_cluster_id_nil=True,
fetch_cluster_id_if_nil=False,
)


def serialize_config(config):
Expand Down Expand Up @@ -448,7 +453,9 @@ def wait_for_node(
TimeoutError: An exception is raised if the timeout expires before
the node appears in the client table.
"""
gcs_options = GcsClientOptions.from_gcs_address(gcs_address)
gcs_options = GcsClientOptions.create(
gcs_address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=False
)
global_state = ray._private.state.GlobalState()
global_state._initialize_global_state(gcs_options)
start_time = time.time()
Expand Down
7 changes: 5 additions & 2 deletions python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ class RayTestTimeoutException(Exception):


def make_global_state_accessor(ray_context):
gcs_options = GcsClientOptions.from_gcs_address(
ray_context.address_info["gcs_address"]
gcs_options = GcsClientOptions.create(
ray_context.address_info["gcs_address"],
None,
allow_cluster_id_nil=True,
fetch_cluster_id_if_nil=False,
)
global_state_accessor = GlobalStateAccessor(gcs_options)
global_state_accessor.connect()
Expand Down
14 changes: 12 additions & 2 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2249,7 +2249,12 @@ def connect(
assert worker.gcs_client is not None
_initialize_internal_kv(worker.gcs_client)
ray._private.state.state._initialize_global_state(
ray._raylet.GcsClientOptions.from_gcs_address(node.gcs_address)
ray._raylet.GcsClientOptions.create(
node.gcs_address,
node.cluster_id.hex(),
allow_cluster_id_nil=False,
fetch_cluster_id_if_nil=False,
)
)
worker.gcs_publisher = ray._raylet.GcsPublisher(address=worker.gcs_client.address)
# Initialize some fields.
Expand Down Expand Up @@ -2309,7 +2314,12 @@ def connect(
elif not LOCAL_MODE:
raise ValueError("Invalid worker mode. Expected DRIVER, WORKER or LOCAL.")

gcs_options = ray._raylet.GcsClientOptions.from_gcs_address(node.gcs_address)
gcs_options = ray._raylet.GcsClientOptions.create(
node.gcs_address,
node.cluster_id.hex(),
allow_cluster_id_nil=False,
fetch_cluster_id_if_nil=False,
)
if job_config is None:
job_config = ray.job_config.JobConfig()

Expand Down
17 changes: 11 additions & 6 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2688,10 +2688,9 @@ def _auto_reconnect(f):
try:
return f(self, *args, **kwargs)
except RpcError as e:
import grpc
if e.rpc_code in [
grpc.StatusCode.UNAVAILABLE.value[0],
grpc.StatusCode.UNKNOWN.value[0],
GRPC_STATUS_CODE_UNAVAILABLE,
GRPC_STATUS_CODE_UNKNOWN,
]:
if remaining_retry <= 0:
logger.error(
Expand Down Expand Up @@ -2726,7 +2725,14 @@ cdef class GcsClient:
nums_reconnect_retry=RayConfig.instance().nums_py_gcs_reconnect_retry(
),
cluster_id: str = None):
cdef GcsClientOptions gcs_options = GcsClientOptions.from_gcs_address(address)
cdef GcsClientOptions gcs_options
if cluster_id:
gcs_options = GcsClientOptions.create(
address, cluster_id, allow_cluster_id_nil=False,
fetch_cluster_id_if_nil=False)
else:
gcs_options = GcsClientOptions.create(
address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=True)
self.inner.reset(new CPythonGcsClient(dereference(gcs_options.native())))
self.address = address
self._nums_reconnect_retry = nums_reconnect_retry
Expand All @@ -2740,9 +2746,8 @@ cdef class GcsClient:
cdef:
int64_t timeout_ms = round(1000 * timeout_s) if timeout_s else -1
size_t num_retries = self._nums_reconnect_retry
CClusterID c_cluster_id = self.cluster_id.native()
with nogil:
status = self.inner.get().Connect(c_cluster_id, timeout_ms, num_retries)
status = self.inner.get().Connect(timeout_ms, num_retries)

check_status(status)

Expand Down
7 changes: 6 additions & 1 deletion python/ray/cluster_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,12 @@ def add_node(self, wait: bool = True, **node_args):
)
self.webui_url = self.head_node.webui_url
# Init global state accessor when creating head node.
gcs_options = GcsClientOptions.from_gcs_address(node.gcs_address)
gcs_options = GcsClientOptions.create(
node.gcs_address,
None,
allow_cluster_id_nil=True,
fetch_cluster_id_if_nil=False,
)
self.global_state._initialize_global_state(gcs_options)
# Write the Ray cluster address for convenience in unit
# testing. ray.init() and ray.init(address="auto") will connect
Expand Down
5 changes: 3 additions & 2 deletions python/ray/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -384,13 +384,14 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil:
UNIMPLEMENTED "grpc::StatusCode::UNIMPLEMENTED",

cdef cppclass CGcsClientOptions "ray::gcs::GcsClientOptions":
CGcsClientOptions(const c_string &gcs_address, int port)
CGcsClientOptions(
const c_string &gcs_address, int port, CClusterID cluster_id,
c_bool allow_cluster_id_nil, c_bool fetch_cluster_id_if_nil)

cdef cppclass CPythonGcsClient "ray::gcs::PythonGcsClient":
CPythonGcsClient(const CGcsClientOptions &options)

CRayStatus Connect(
const CClusterID &cluster_id,
int64_t timeout_ms,
size_t num_retries)
CRayStatus CheckAlive(
Expand Down
13 changes: 11 additions & 2 deletions python/ray/includes/common.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,22 @@ cdef class GcsClientOptions:
unique_ptr[CGcsClientOptions] inner

@classmethod
def create(
    cls, gcs_address, cluster_id_hex, allow_cluster_id_nil, fetch_cluster_id_if_nil
):
    """
    Creates a GcsClientOption with a maybe-Nil cluster_id, and may fetch from GCS.

    Args:
        gcs_address: GCS address string in "ip:port" form.
        cluster_id_hex: Hex string of the cluster ID. When falsy (e.g. None),
            a Nil cluster ID is used instead.
        allow_cluster_id_nil: Forwarded unchanged to the native
            GcsClientOptions constructor.
        fetch_cluster_id_if_nil: Forwarded unchanged to the native
            GcsClientOptions constructor.

    Returns:
        A GcsClientOptions instance wrapping the native options object.

    Raises:
        ValueError: If parsing gcs_address or constructing the native
            options fails for any reason.
    """
    # The cluster ID is parsed outside the try block, so a failure here is
    # not reported as an invalid gcs_address.
    cdef CClusterID c_cluster_id = CClusterID.Nil()
    if cluster_id_hex:
        c_cluster_id = CClusterID.FromHex(cluster_id_hex)
    self = GcsClientOptions()
    try:
        ip, port = gcs_address.split(":", 2)
        port = int(port)
        # The native constructor takes the two flags as distinct trailing
        # arguments: (..., allow_cluster_id_nil, fetch_cluster_id_if_nil).
        # Each caller-supplied flag must land in its own slot.
        self.inner.reset(
            new CGcsClientOptions(
                ip, port, c_cluster_id, allow_cluster_id_nil, fetch_cluster_id_if_nil))
    except Exception:
        raise ValueError(f"Invalid gcs_address: {gcs_address}")
    return self
Expand Down
4 changes: 3 additions & 1 deletion python/ray/tests/test_object_spilling.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def is_dir_empty(temp_folder, node_id, append_path=True):

def assert_no_thrashing(address):
state = ray._private.state.GlobalState()
options = GcsClientOptions.from_gcs_address(address)
options = GcsClientOptions.create(
address, None, allow_cluster_id_nil=True, fetch_cluster_id_if_nil=False
)
state._initialize_global_state(options)
summary = memory_summary(address=address, stats_only=True)
restored_bytes = 0
Expand Down
2 changes: 1 addition & 1 deletion src/mock/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MockGcsClient : public GcsClient {
public:
MOCK_METHOD(Status,
Connect,
(instrumented_io_context & io_service, const ClusterID &cluster_id),
(instrumented_io_context & io_service, int64_t timeout_ms),
(override));
MOCK_METHOD(void, Disconnect, (), (override));
MOCK_METHOD((std::pair<std::string, int>), GetGcsServerAddress, (), (const, override));
Expand Down
7 changes: 3 additions & 4 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,12 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_

gcs_client_ = std::make_shared<gcs::GcsClient>(options_.gcs_options, GetWorkerID());

RAY_CHECK_OK(gcs_client_->Connect(io_service_, options_.cluster_id));
RAY_CHECK_OK(gcs_client_->Connect(io_service_));
RegisterToGcs(options_.worker_launch_time_ms, options_.worker_launched_time_ms);

// Initialize the task state event buffer.
auto task_event_gcs_client = std::make_unique<gcs::GcsClient>(options_.gcs_options);
task_event_buffer_ =
std::make_unique<worker::TaskEventBufferImpl>(std::move(task_event_gcs_client));
task_event_buffer_ = std::make_unique<worker::TaskEventBufferImpl>(
std::make_shared<gcs::GcsClient>(options_.gcs_options));
if (RayConfig::instance().task_events_report_interval_ms() > 0) {
if (!task_event_buffer_->Start().ok()) {
RAY_CHECK(!task_event_buffer_->Enabled()) << "TaskEventBuffer should be disabled.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ inline gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env, jobject gcs_client_
env,
(jstring)env->GetObjectField(gcs_client_options, java_gcs_client_options_password));

return gcs::GcsClientOptions(ip + ":" + std::to_string(port));
return gcs::GcsClientOptions(ip,
port,
ray::ClusterID::Nil(),
/*allow_cluster_id_nil=*/true,
/*fetch_cluster_id_if_nil=*/false);
}

jobject ToJavaArgs(JNIEnv *env,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ Java_io_ray_runtime_gcs_GlobalStateAccessor_nativeCreateGlobalStateAccessor(
std::string bootstrap_address = JavaStringToNativeString(env, j_bootstrap_address);
std::string redis_password = JavaStringToNativeString(env, j_redis_passowrd);
gcs::GlobalStateAccessor *gcs_accessor = nullptr;
ray::gcs::GcsClientOptions client_options(bootstrap_address);
ray::gcs::GcsClientOptions client_options(bootstrap_address,
ray::ClusterID::Nil(),
/*allow_cluster_id_nil=*/true,
/*fetch_cluster_id_if_nil=*/false);
gcs_accessor = new gcs::GlobalStateAccessor(client_options);
return reinterpret_cast<jlong>(gcs_accessor);
}
Expand Down
28 changes: 15 additions & 13 deletions src/ray/core_worker/task_event_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void TaskProfileEvent::ToRpcTaskEvents(rpc::TaskEvents *rpc_task_events) {
event_entry->set_extra_data(std::move(extra_data_));
}

TaskEventBufferImpl::TaskEventBufferImpl(std::unique_ptr<gcs::GcsClient> gcs_client)
TaskEventBufferImpl::TaskEventBufferImpl(std::shared_ptr<gcs::GcsClient> gcs_client)
: work_guard_(boost::asio::make_work_guard(io_service_)),
periodical_runner_(io_service_),
gcs_client_(std::move(gcs_client)),
Expand All @@ -144,18 +144,6 @@ Status TaskEventBufferImpl::Start(bool auto_flush) {
status_events_.set_capacity(
RayConfig::instance().task_events_max_num_status_events_buffer_on_worker());

// Reporting to GCS, set up gcs client and events flushing.
auto status = gcs_client_->Connect(io_service_);
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to connect to GCS, TaskEventBuffer will stop now. [status="
<< status.ToString() << "].";

enabled_ = false;
return status;
}

enabled_ = true;

io_thread_ = std::thread([this]() {
#ifndef _WIN32
// Block SIGINT and SIGTERM so they will be handled by the main thread.
Expand All @@ -170,6 +158,20 @@ Status TaskEventBufferImpl::Start(bool auto_flush) {
RAY_LOG(INFO) << "Task event buffer io service stopped.";
});

// Reporting to GCS, set up gcs client and events flushing.
auto status = gcs_client_->Connect(io_service_);
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to connect to GCS, TaskEventBuffer will stop now. [status="
<< status.ToString() << "].";

enabled_ = false;
io_service_.stop();
io_thread_.join();
return status;
}

enabled_ = true;

if (!auto_flush) {
return Status::OK();
}
Expand Down
4 changes: 2 additions & 2 deletions src/ray/core_worker/task_event_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class TaskEventBufferImpl : public TaskEventBuffer {
/// Constructor
///
/// \param gcs_client GCS client
TaskEventBufferImpl(std::unique_ptr<gcs::GcsClient> gcs_client);
TaskEventBufferImpl(std::shared_ptr<gcs::GcsClient> gcs_client);

~TaskEventBufferImpl() override;

Expand Down Expand Up @@ -381,7 +381,7 @@ class TaskEventBufferImpl : public TaskEventBuffer {
PeriodicalRunner periodical_runner_;

/// Client to the GCS used to push profile events to it.
std::unique_ptr<gcs::GcsClient> gcs_client_ ABSL_GUARDED_BY(mutex_);
std::shared_ptr<gcs::GcsClient> gcs_client_ ABSL_GUARDED_BY(mutex_);

/// True if the TaskEventBuffer is enabled.
std::atomic<bool> enabled_ = false;
Expand Down
Loading

0 comments on commit 6a7521c

Please sign in to comment.