Skip to content

Commit

Permalink
[core] Add template parameter for RPC auth in GCS (ray-project#37307)
Browse files Browse the repository at this point in the history
Signed-off-by: vitsai <[email protected]>
Signed-off-by: vitsai <[email protected]>
  • Loading branch information
vitsai authored Aug 25, 2023
1 parent 7ffda41 commit 3f11cf4
Show file tree
Hide file tree
Showing 22 changed files with 330 additions and 124 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ cc_library(
"//src/ray/common:asio",
"//src/ray/common:grpc_util",
"//src/ray/common:id",
"//src/ray/common:ray_config",
"//src/ray/common:status",
"@com_github_grpc_grpc//:grpc++",
"@com_github_grpc_grpc//:grpc++_reflection",
Expand Down
40 changes: 40 additions & 0 deletions python/ray/tests/test_gcs_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ray._private import ray_constants
from ray._private.test_utils import (
convert_actor_state,
enable_external_redis,
generate_system_config_map,
wait_for_condition,
wait_for_pid_to_exit,
Expand Down Expand Up @@ -828,6 +829,45 @@ def f():
wait_for_pid_to_exit(gcs_server_pid, 10000)


@pytest.mark.parametrize(
    "ray_start_regular",
    [generate_system_config_map(enable_cluster_auth=True)],
    indirect=True,
)
def test_cluster_id(ray_start_regular):
    """Restart the GCS and check how the raylet reacts.

    Without external Redis, the raylet is expected to kill itself after the
    GCS restart (cluster ID mismatch in the non-persisted case). With
    external Redis, the raylet is expected to stay alive.
    """
    node = ray._private.worker._global_node
    raylet_proc = node.all_processes[ray_constants.PROCESS_TYPE_RAYLET][0].process

    def check_raylet_healthy():
        # poll() returns None while the process is still running.
        return raylet_proc.poll() is None

    def assert_raylet_stays_healthy(seconds):
        # Sample the raylet once per second and fail as soon as it has exited.
        for _ in range(seconds):
            assert check_raylet_healthy()
            sleep(1)

    wait_for_condition(check_raylet_healthy)
    assert_raylet_stays_healthy(10)

    node.kill_gcs_server()
    node.start_gcs_server()

    if enable_external_redis():
        # Waiting for raylet to stay healthy
        assert_raylet_stays_healthy(10)
    else:
        # Waiting for raylet to become unhealthy
        wait_for_condition(lambda: not check_raylet_healthy())


@pytest.mark.parametrize(
"ray_start_regular_with_external_redis",
[
Expand Down
3 changes: 3 additions & 0 deletions src/ray/common/ray_config_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ RAY_CONFIG(bool, event_stats_metrics, false)
/// TODO(ekl) remove this after Ray 1.8
RAY_CONFIG(bool, legacy_scheduler_warnings, false)

/// Whether to enable cluster authentication.
RAY_CONFIG(bool, enable_cluster_auth, false)

/// The interval of periodic event loop stats print.
/// -1 means the feature is disabled. In this case, stats are available to
/// debug_state_*.txt
Expand Down
3 changes: 3 additions & 0 deletions src/ray/common/status.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ namespace ray {
#define STATUS_CODE_NOT_FOUND "NotFound"
#define STATUS_CODE_DISCONNECTED "Disconnected"
#define STATUS_CODE_SCHEDULING_CANCELLED "SchedulingCancelled"
#define STATUS_CODE_AUTH_ERROR "AuthError"
// object store status
#define STATUS_CODE_OBJECT_EXISTS "ObjectExists"
#define STATUS_CODE_OBJECT_NOT_FOUND "ObjectNotFound"
Expand Down Expand Up @@ -114,6 +115,7 @@ std::string Status::CodeAsString() const {
{StatusCode::TransientObjectStoreFull, STATUS_CODE_TRANSIENT_OBJECT_STORE_FULL},
{StatusCode::GrpcUnavailable, STATUS_CODE_GRPC_UNAVAILABLE},
{StatusCode::GrpcUnknown, STATUS_CODE_GRPC_UNKNOWN},
{StatusCode::AuthError, STATUS_CODE_AUTH_ERROR},
};

auto it = code_to_str.find(code());
Expand Down Expand Up @@ -149,6 +151,7 @@ StatusCode Status::StringToCode(const std::string &str) {
{STATUS_CODE_OBJECT_UNKNOWN_OWNER, StatusCode::ObjectUnknownOwner},
{STATUS_CODE_OBJECT_STORE_FULL, StatusCode::ObjectStoreFull},
{STATUS_CODE_TRANSIENT_OBJECT_STORE_FULL, StatusCode::TransientObjectStoreFull},
{STATUS_CODE_AUTH_ERROR, StatusCode::AuthError},
};

auto it = str_to_code.find(str);
Expand Down
10 changes: 8 additions & 2 deletions src/ray/common/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ enum class StatusCode : char {
ObjectUnknownOwner = 29,
RpcError = 30,
OutOfResource = 31,
// Indicates that the ObjectRefStream has reached the end of the stream.
ObjectRefEndOfStream = 32
ObjectRefEndOfStream = 32,
AuthError = 33,
};

#if defined(__clang__)
Expand Down Expand Up @@ -252,6 +252,10 @@ class RAY_EXPORT Status {
return Status(StatusCode::OutOfResource, msg);
}

// Returns a Status with code StatusCode::AuthError carrying `msg`.
static Status AuthError(const std::string &msg) {
return Status(StatusCode::AuthError, msg);
}

static StatusCode StringToCode(const std::string &str);

// Returns true iff the status indicates success.
Expand Down Expand Up @@ -303,6 +307,8 @@ class RAY_EXPORT Status {

bool IsOutOfResource() const { return code() == StatusCode::OutOfResource; }

// Returns true iff the status code is StatusCode::AuthError.
bool IsAuthError() const { return code() == StatusCode::AuthError; }

// Return a string representation of this status suitable for printing.
// Returns the string "OK" for success.
std::string ToString() const;
Expand Down
7 changes: 7 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
/// Connect to GCS Service. Non-thread safe.
/// This function must be called before calling other functions.
/// \param instrumented_io_context IO execution service.
/// \param cluster_id Optional cluster ID to provide to the client.
///
/// \return Status
virtual Status Connect(instrumented_io_context &io_service,
Expand Down Expand Up @@ -156,6 +157,11 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
return *placement_group_accessor_;
}

/// Get the cluster ID held by the underlying client call manager.
/// The RAY_CHECK below fails when `client_call_manager_` is null
/// (NOTE(review): presumably this means Connect() has not been called yet --
/// confirm).
///
/// \return The cluster ID reported by `client_call_manager_`.
const ClusterID &GetClusterId() {
RAY_CHECK(client_call_manager_) << "Cannot retrieve cluster ID before it is set.";
return client_call_manager_->GetClusterId();
}

/// Get the sub-interface for accessing worker information in GCS.
/// This function is thread safe.
virtual InternalKVAccessor &InternalKV() { return *internal_kv_accessor_; }
Expand All @@ -175,6 +181,7 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
std::unique_ptr<WorkerInfoAccessor> worker_accessor_;
std::unique_ptr<PlacementGroupInfoAccessor> placement_group_accessor_;
std::unique_ptr<InternalKVAccessor> internal_kv_accessor_;

std::unique_ptr<TaskInfoAccessor> task_accessor_;

private:
Expand Down
57 changes: 55 additions & 2 deletions src/ray/gcs/gcs_client/test/gcs_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
// Create GCS client.
gcs::GcsClientOptions options("127.0.0.1:5397");
gcs_client_ = std::make_unique<gcs::GcsClient>(options);
RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_));
ReconnectClient();
}

void TearDown() override {
Expand All @@ -115,6 +115,15 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
rpc::ResetServerCallExecutor();
}

void ReconnectClient() {
ClusterID cluster_id = gcs_server_->rpc_server_.GetClusterId();
RAY_CHECK_OK(gcs_client_->Connect(*client_io_service_, cluster_id));
}

/// Attach the client's cluster ID, hex-encoded, to the metadata of `context`
/// under the key `kClusterIdKey`.
void StampContext(grpc::ClientContext &context) {
  const auto &client_cluster_id = gcs_client_->GetClusterId();
  context.AddMetadata(kClusterIdKey, client_cluster_id.Hex());
}

void RestartGcsServer() {
RAY_LOG(INFO) << "Stopping GCS service, port = " << gcs_server_->GetPort();
gcs_server_->Stop();
Expand Down Expand Up @@ -142,11 +151,14 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
grpc::InsecureChannelCredentials());
auto stub = rpc::NodeInfoGcsService::NewStub(std::move(channel));
grpc::ClientContext context;
StampContext(context);
context.set_deadline(std::chrono::system_clock::now() + 1s);
const rpc::CheckAliveRequest request;
rpc::CheckAliveReply reply;
auto status = stub->CheckAlive(&context, request, &reply);
if (!status.ok()) {
// If it is in memory, we don't have the new token until we connect again.
if (!((!no_redis_ && status.ok()) ||
(no_redis_ && GrpcStatusToRayStatus(status).IsAuthError()))) {
RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " "
<< status.error_message();
continue;
Expand Down Expand Up @@ -923,6 +935,7 @@ TEST_P(GcsClientTest, TestEvictExpiredDestroyedActors) {

// Restart GCS.
RestartGcsServer();
ReconnectClient();

for (int index = 0; index < actor_count; ++index) {
auto actor_table_data = Mocker::GenActorTableData(job_id);
Expand All @@ -945,9 +958,49 @@ TEST_P(GcsClientTest, TestEvictExpiredDestroyedActors) {
}
}

// Calls GetClusterId on a restarted GCS with the client's current cluster ID
// stamped into the request metadata, and expects an auth error in response.
TEST_P(GcsClientTest, TestGcsEmptyAuth) {
  RayConfig::instance().initialize(R"({"enable_cluster_auth": true})");
  // Restart GCS.
  RestartGcsServer();

  // Talk to the server through a raw stub rather than through gcs_client_.
  const auto gcs_address = absl::StrCat("127.0.0.1:", gcs_server_->GetPort());
  auto node_info_stub = rpc::NodeInfoGcsService::NewStub(
      grpc::CreateChannel(gcs_address, grpc::InsecureChannelCredentials()));

  grpc::ClientContext client_context;
  StampContext(client_context);
  client_context.set_deadline(std::chrono::system_clock::now() + 1s);

  const rpc::GetClusterIdRequest request;
  rpc::GetClusterIdReply reply;
  const auto rpc_status = node_info_stub->GetClusterId(&client_context, request, &reply);

  // The stamped request is expected to be rejected with an auth error.
  EXPECT_TRUE(GrpcStatusToRayStatus(rpc_status).IsAuthError());
}

// Checks node registration against a restarted GCS with cluster auth enabled,
// both with and without Redis-backed storage.
TEST_P(GcsClientTest, TestGcsAuth) {
  RayConfig::instance().initialize(R"({"enable_cluster_auth": true})");
  // Restart GCS.
  RestartGcsServer();
  auto node_info = Mocker::GenNodeInfo();

  if (no_redis_) {
    // If we are not backed by Redis, we need to first fetch
    // the new cluster ID, so we expect failure before success.
    EXPECT_FALSE(RegisterNode(*node_info));
    ReconnectClient();
    EXPECT_TRUE(RegisterNode(*node_info));
  } else {
    // If we are backed by Redis, we can reuse cluster ID, so the RPC passes.
    EXPECT_TRUE(RegisterNode(*node_info));
  }
}

TEST_P(GcsClientTest, TestEvictExpiredDeadNodes) {
RayConfig::instance().initialize(R"({"enable_cluster_auth": true})");
// Restart GCS.
RestartGcsServer();
if (RayConfig::instance().gcs_storage() == gcs::GcsServer::kInMemoryStorage) {
ReconnectClient();
}

// Simulate the scenario of node dead.
int node_count = RayConfig::instance().maximum_gcs_dead_node_cached_count();
Expand Down
3 changes: 2 additions & 1 deletion src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config,
rpc_server_(config.grpc_server_name,
config.grpc_server_port,
config.node_ip_address == "127.0.0.1",
ClusterID::Nil(),
config.grpc_server_thread_num,
/*keepalive_time_ms=*/RayConfig::instance().grpc_keepalive_time_ms()),
client_call_manager_(main_service,
Expand Down Expand Up @@ -167,7 +168,7 @@ void GcsServer::GetOrGenerateClusterId(
kClusterIdKey,
cluster_id.Binary(),
false,
[&cluster_id,
[cluster_id,
continuation = std::move(continuation)](bool added_entry) mutable {
RAY_CHECK(added_entry) << "Failed to persist new cluster ID!";
continuation(cluster_id);
Expand Down
4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_server/gcs_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
namespace ray {
using raylet::ClusterTaskManager;
using raylet::NoopLocalTaskManager;

class GcsClientTest;
namespace gcs {

struct GcsServerConfig {
Expand Down Expand Up @@ -172,6 +174,8 @@ class GcsServer {
/// Initialize monitor service.
void InitMonitorServer();

friend class ray::GcsClientTest;

private:
/// Gets the type of KV storage to use from config.
StorageType GetStorageType() const;
Expand Down
1 change: 1 addition & 0 deletions src/ray/object_manager/object_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ ObjectManager::ObjectManager(
object_manager_server_("ObjectManager",
config_.object_manager_port,
config_.object_manager_address == "127.0.0.1",
ClusterID::Nil(),
config_.rpc_service_threads_number),
object_manager_service_(rpc_service_, *this),
client_call_manager_(
Expand Down
21 changes: 12 additions & 9 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,15 +550,18 @@ ray::Status NodeManager::RegisterGcs() {
checking = true;
RAY_CHECK_OK(gcs_client_->Nodes().AsyncCheckSelfAlive(
// capture checking ptr here because vs17 fail to compile
[checking_ptr = &checking](auto status, auto alive) mutable {
if (status.ok()) {
if (!alive) {
// GCS think this raylet is dead. Fail the node
RAY_LOG(FATAL)
<< "GCS consider this node to be dead. This may happen when "
<< "GCS is not backed by a DB and restarted or there is data loss "
<< "in the DB.";
}
[this, checking_ptr = &checking](auto status, auto alive) mutable {
if ((status.ok() && !alive)) {
// GCS thinks this raylet is dead. Fail the node.
RAY_LOG(FATAL)
<< "GCS consider this node to be dead. This may happen when "
<< "GCS is not backed by a DB and restarted or there is data loss "
<< "in the DB.";
} else if (status.IsAuthError()) {
RAY_LOG(FATAL)
<< "GCS returned an authentication error. This may happen when "
<< "GCS is not backed by a DB and restarted or there is data loss "
<< "in the DB. Local cluster ID: " << gcs_client_->GetClusterId();
}
*checking_ptr = false;
},
Expand Down
3 changes: 2 additions & 1 deletion src/ray/rpc/agent_manager/agent_manager_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace ray {
namespace rpc {

#define RAY_AGENT_MANAGER_RPC_HANDLERS \
RPC_SERVICE_HANDLER(AgentManagerService, RegisterAgent, -1)
RPC_SERVICE_HANDLER_CUSTOM_AUTH( \
AgentManagerService, RegisterAgent, -1, AuthType::NO_AUTH)

/// Implementations of the `AgentManagerGrpcService`, check interface in
/// `src/ray/protobuf/agent_manager.proto`.
Expand Down
4 changes: 4 additions & 0 deletions src/ray/rpc/client_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class ClientCallManager {
///
/// \param[in] main_service The main event loop, to which the callback functions will be
/// posted.
///
explicit ClientCallManager(instrumented_io_context &main_service,
const ClusterID &cluster_id = ClusterID::Nil(),
int num_threads = 1,
Expand Down Expand Up @@ -267,6 +268,9 @@ class ClientCallManager {
return call;
}

/// Get the cluster ID held by this call manager.
///
/// \return A const reference to `cluster_id_`.
const ClusterID &GetClusterId() const { return cluster_id_; }

/// Get the main service of this rpc.
instrumented_io_context &GetMainService() { return main_service_; }

Expand Down
2 changes: 1 addition & 1 deletion src/ray/rpc/gcs_server/gcs_rpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class GcsRpcClient {
public:
/// Constructor. GcsRpcClient is not thread safe.
///
/// \param[in] address Address of gcs server.
/// \param[in] address Address of gcs server.
/// \param[in] port Port of the gcs server.
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
/// \param[in] gcs_service_failure_detected The function is used to redo subscription
Expand Down
9 changes: 7 additions & 2 deletions src/ray/rpc/gcs_server/gcs_rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ namespace rpc {
HANDLER, \
RayConfig::instance().gcs_max_active_rpcs_per_handler())

// TODO(vitsai): Set auth for everything except GCS.
// Expands to RPC_SERVICE_HANDLER for the InternalKVGcsService handler HANDLER.
// NOTE(review): the trailing -1 is presumably "no limit on active RPCs" --
// confirm against RPC_SERVICE_HANDLER.
#define INTERNAL_KV_SERVICE_RPC_HANDLER(HANDLER) \
RPC_SERVICE_HANDLER(InternalKVGcsService, HANDLER, -1)

Expand Down Expand Up @@ -382,7 +381,13 @@ class NodeInfoGrpcService : public GrpcService {
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
const ClusterID &cluster_id) override {
NODE_INFO_SERVICE_RPC_HANDLER(GetClusterId);
// We only allow one cluster ID in the lifetime of a client.
// So, if a client connects, it should not have a pre-existing different ID.
RPC_SERVICE_HANDLER_CUSTOM_AUTH(
NodeInfoGcsService,
GetClusterId,
RayConfig::instance().gcs_max_active_rpcs_per_handler(),
AuthType::EMPTY_AUTH);
NODE_INFO_SERVICE_RPC_HANDLER(RegisterNode);
NODE_INFO_SERVICE_RPC_HANDLER(DrainNode);
NODE_INFO_SERVICE_RPC_HANDLER(GetAllNodeInfo);
Expand Down
Loading

0 comments on commit 3f11cf4

Please sign in to comment.