Skip to content

Commit

Permalink
GCS server use worker table to handle RegisterWorker instead of redis…
Browse files Browse the repository at this point in the history
… accessor (ray-project#9168)
  • Loading branch information
kisuke95 authored Jul 6, 2020
1 parent dcf9892 commit 6f3d993
Show file tree
Hide file tree
Showing 40 changed files with 653 additions and 299 deletions.
5 changes: 5 additions & 0 deletions python/ray/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ResourceTableData,
ObjectLocationInfo,
PubSubMessage,
WorkerTableData,
)

__all__ = [
Expand All @@ -39,6 +40,7 @@
"construct_error_message",
"ObjectLocationInfo",
"PubSubMessage",
"WorkerTableData",
]

FUNCTION_PREFIX = "RemoteFunction:"
Expand Down Expand Up @@ -69,6 +71,9 @@
TablePrefix_JOB_string = "JOB"
TablePrefix_ACTOR_string = "ACTOR"

WORKER = 0
DRIVER = 1


def construct_error_message(job_id, error_type, message, timestamp):
"""Construct a serialized ErrorTableData object.
Expand Down
4 changes: 4 additions & 0 deletions python/ray/includes/global_state_accessor.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ from ray.includes.unique_ids cimport (
CActorID,
CClientID,
CObjectID,
CWorkerID,
)

cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
Expand All @@ -23,3 +24,6 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
c_vector[c_string] GetAllActorInfo()
unique_ptr[c_string] GetActorInfo(const CActorID &actor_id)
c_string GetNodeResourceInfo(const CClientID &node_id)
unique_ptr[c_string] GetWorkerInfo(const CWorkerID &worker_id)
c_vector[c_string] GetAllWorkerInfo()
c_bool AddWorkerInfo(const c_string &serialized_string)
13 changes: 13 additions & 0 deletions python/ray/includes/global_state_accessor.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from ray.includes.unique_ids cimport (
CActorID,
CClientID,
CObjectID,
CWorkerID,
)

from ray.includes.global_state_accessor cimport (
Expand Down Expand Up @@ -57,3 +58,15 @@ cdef class GlobalStateAccessor:

def get_node_resource_info(self, node_id):
return self.inner.get().GetNodeResourceInfo(CClientID.FromBinary(node_id.binary()))

def get_worker_table(self):
return self.inner.get().GetAllWorkerInfo()

def get_worker_info(self, worker_id):
worker_info = self.inner.get().GetWorkerInfo(CWorkerID.FromBinary(worker_id.binary()))
if worker_info:
return c_string(worker_info.get().data(), worker_info.get().size())
return None

def add_worker_info(self, serialized_string):
return self.inner.get().AddWorkerInfo(serialized_string)
58 changes: 42 additions & 16 deletions python/ray/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,25 +602,51 @@ def workers(self):
"""Get a dictionary mapping worker ID to worker information."""
self._check_connected()

worker_keys = self.redis_client.keys("Worker*")
# Get all data in worker table
worker_table = self.global_state_accessor.get_worker_table()
workers_data = {}
for i in range(len(worker_table)):
worker_table_data = gcs_utils.WorkerTableData.FromString(
worker_table[i])
if worker_table_data.is_alive and \
worker_table_data.worker_type == gcs_utils.WORKER:
worker_id = binary_to_hex(
worker_table_data.worker_address.worker_id)
worker_info = worker_table_data.worker_info

workers_data[worker_id] = {
"node_ip_address": decode(worker_info[b"node_ip_address"]),
"plasma_store_socket": decode(
worker_info[b"plasma_store_socket"])
}
if b"stderr_file" in worker_info:
workers_data[worker_id]["stderr_file"] = decode(
worker_info[b"stderr_file"])
if b"stdout_file" in worker_info:
workers_data[worker_id]["stdout_file"] = decode(
worker_info[b"stdout_file"])
return workers_data

for worker_key in worker_keys:
worker_info = self.redis_client.hgetall(worker_key)
worker_id = binary_to_hex(worker_key[len("Workers:"):])
def add_worker(self, worker_id, worker_type, worker_info):
""" Add a worker to the cluster.
workers_data[worker_id] = {
"node_ip_address": decode(worker_info[b"node_ip_address"]),
"plasma_store_socket": decode(
worker_info[b"plasma_store_socket"])
}
if b"stderr_file" in worker_info:
workers_data[worker_id]["stderr_file"] = decode(
worker_info[b"stderr_file"])
if b"stdout_file" in worker_info:
workers_data[worker_id]["stdout_file"] = decode(
worker_info[b"stdout_file"])
return workers_data
Args:
worker_id: ID of this worker. Type is bytes.
worker_type: Type of this worker. Value is ray.gcs_utils.DRIVER or
ray.gcs_utils.WORKER.
worker_info: Info of this worker. Type is dict{str: str}.
Returns:
Is operation success
"""
worker_data = ray.gcs_utils.WorkerTableData()
worker_data.is_alive = True
worker_data.worker_address.worker_id = worker_id
worker_data.worker_type = worker_type
for k, v in worker_info.items():
worker_data.worker_info[k] = bytes(v, encoding="utf-8")
return self.global_state_accessor.add_worker_info(
worker_data.SerializeToString())

def _job_length(self):
event_log_sets = self.redis_client.keys("event_log*")
Expand Down
13 changes: 6 additions & 7 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,15 +872,14 @@ def sigterm_handler(signum, frame):


def custom_excepthook(type, value, tb):
# If this is a driver, push the exception to redis.
# If this is a driver, push the exception to GCS worker table.
if global_worker.mode == SCRIPT_MODE:
error_message = "".join(traceback.format_tb(tb))
try:
global_worker.redis_client.hmset(
b"Drivers:" + global_worker.worker_id,
{"exception": error_message})
except (ConnectionRefusedError, redis.exceptions.ConnectionError):
logger.warning("Could not push exception to redis.")
worker_id = global_worker.worker_id
worker_type = ray.gcs_utils.DRIVER
worker_info = {"exception": error_message}

ray.state.state.add_worker(worker_id, worker_type, worker_info)
# Call the normal excepthook.
normal_excepthook(type, value, tb)

Expand Down
2 changes: 1 addition & 1 deletion src/ray/common/scheduling/cluster_resource_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ int64_t ClusterResourceScheduler::IsSchedulable(const TaskRequest &task_req,
}
}

// No check custom resources.
// Now check custom resources.
for (const auto task_req_custom_resource : task_req.custom_resources) {
auto it = resources.custom_resources.find(task_req_custom_resource.id);

Expand Down
2 changes: 1 addition & 1 deletion src/ray/common/scheduling/cluster_resource_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class TaskRequest {
public:
/// List of predefined resources required by the task.
std::vector<ResourceRequest> predefined_resources;
/// List of custom resources required by the tasl.
/// List of custom resources required by the task.
std::vector<ResourceRequestWithId> custom_resources;
/// List of placement hints. A placement hint is a node on which
/// we desire to run this task. This is a soft constraint in that
Expand Down
8 changes: 6 additions & 2 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,12 @@ void CoreWorker::RegisterToGcs() {
worker_info.emplace("stderr_file", options_.stderr_file);
}

RAY_CHECK_OK(gcs_client_->Workers().AsyncRegisterWorker(options_.worker_type, worker_id,
worker_info, nullptr));
auto worker_data = std::make_shared<rpc::WorkerTableData>();
worker_data->mutable_worker_address()->set_worker_id(worker_id.Binary());
worker_data->set_worker_type(options_.worker_type);
worker_data->mutable_worker_info()->insert(worker_info.begin(), worker_info.end());

RAY_CHECK_OK(gcs_client_->Workers().AsyncAdd(worker_data, nullptr));
}
void CoreWorker::CheckForRayletFailure(const boost::system::error_code &error) {
if (error == boost::asio::error::operation_aborted) {
Expand Down
34 changes: 23 additions & 11 deletions src/ray/gcs/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ class WorkerInfoAccessor {
/// \param done Callback that will be called when subscription is complete.
/// \return Status
virtual Status AsyncSubscribeToWorkerFailures(
const SubscribeCallback<WorkerID, rpc::WorkerFailureData> &subscribe,
const SubscribeCallback<WorkerID, rpc::WorkerTableData> &subscribe,
const StatusCallback &done) = 0;

/// Report a worker failure to GCS asynchronously.
Expand All @@ -656,19 +656,31 @@ class WorkerInfoAccessor {
/// \param callback Callback that will be called when report is complate.
/// \param Status
virtual Status AsyncReportWorkerFailure(
const std::shared_ptr<rpc::WorkerFailureData> &data_ptr,
const std::shared_ptr<rpc::WorkerTableData> &data_ptr,
const StatusCallback &callback) = 0;

/// Register a worker to GCS asynchronously.
/// Get worker specification from GCS asynchronously.
///
/// \param worker_type The type of the worker.
/// \param worker_id The ID of the worker.
/// \param worker_info The information of the worker.
/// \return Status.
virtual Status AsyncRegisterWorker(
rpc::WorkerType worker_type, const WorkerID &worker_id,
const std::unordered_map<std::string, std::string> &worker_info,
const StatusCallback &callback) = 0;
/// \param worker_id The ID of worker to look up in the GCS.
/// \param callback Callback that will be called after lookup finishes.
/// \return Status
virtual Status AsyncGet(const WorkerID &worker_id,
const OptionalItemCallback<rpc::WorkerTableData> &callback) = 0;

/// Get all worker info from GCS asynchronously.
///
/// \param callback Callback that will be called after lookup finished.
/// \return Status
virtual Status AsyncGetAll(const MultiItemCallback<rpc::WorkerTableData> &callback) = 0;

/// Add worker information to GCS asynchronously.
///
/// \param data_ptr The worker that will be add to GCS.
/// \param callback Callback that will be called after worker information has been added
/// to GCS.
/// \return Status
virtual Status AsyncAdd(const std::shared_ptr<rpc::WorkerTableData> &data_ptr,
const StatusCallback &callback) = 0;

/// Reestablish subscription.
/// This should be called when GCS server restarts from a failure.
Expand Down
33 changes: 33 additions & 0 deletions src/ray/gcs/gcs_client/global_state_accessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,5 +178,38 @@ std::unique_ptr<std::string> GlobalStateAccessor::GetActorCheckpointId(
return actor_checkpoint_id_data;
}

std::unique_ptr<std::string> GlobalStateAccessor::GetWorkerInfo(
const WorkerID &worker_id) {
std::unique_ptr<std::string> worker_table_data;
std::promise<bool> promise;
RAY_CHECK_OK(gcs_client_->Workers().AsyncGet(
worker_id, TransformForOptionalItemCallback<rpc::WorkerTableData>(worker_table_data,
promise)));
promise.get_future().get();
return worker_table_data;
}

std::vector<std::string> GlobalStateAccessor::GetAllWorkerInfo() {
std::vector<std::string> worker_table_data;
std::promise<bool> promise;
RAY_CHECK_OK(gcs_client_->Workers().AsyncGetAll(
TransformForMultiItemCallback<rpc::WorkerTableData>(worker_table_data, promise)));
promise.get_future().get();
return worker_table_data;
}

bool GlobalStateAccessor::AddWorkerInfo(const std::string &serialized_string) {
auto data_ptr = std::make_shared<WorkerTableData>();
data_ptr->ParseFromString(serialized_string);
std::promise<bool> promise;
RAY_CHECK_OK(
gcs_client_->Workers().AsyncAdd(data_ptr, [&promise](const Status &status) {
RAY_CHECK_OK(status);
promise.set_value(true);
}));
promise.get_future().get();
return true;
}

} // namespace gcs
} // namespace ray
22 changes: 22 additions & 0 deletions src/ray/gcs/gcs_client/global_state_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,28 @@ class GlobalStateAccessor {
/// deserialized with protobuf function.
std::unique_ptr<std::string> GetActorCheckpointId(const ActorID &actor_id);

/// Get information of a worker from GCS Service.
///
/// \param worker_id The ID of worker to look up in the GCS Service.
/// \return Worker info. To support multi-language, we serialize each WorkerTableData
/// and return the serialized string. Where used, it needs to be deserialized with
/// protobuf function.
std::unique_ptr<std::string> GetWorkerInfo(const WorkerID &worker_id);

/// Get information of all workers from GCS Service.
///
/// \return All worker info. To support multi-language, we serialize each
/// WorkerTableData and return the serialized string. Where used, it needs to be
/// deserialized with protobuf function.
std::vector<std::string> GetAllWorkerInfo();

/// Add information of a worker to GCS Service.
///
/// \param serialized_string The serialized data of worker to be added in the GCS
/// Service, use string is convenient for python to use.
/// \return Is operation success.
bool AddWorkerInfo(const std::string &serialized_string);

private:
/// MultiItem transformation helper in template style.
///
Expand Down
57 changes: 41 additions & 16 deletions src/ray/gcs/gcs_client/service_based_accessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1252,17 +1252,16 @@ ServiceBasedWorkerInfoAccessor::ServiceBasedWorkerInfoAccessor(
: client_impl_(client_impl) {}

Status ServiceBasedWorkerInfoAccessor::AsyncSubscribeToWorkerFailures(
const SubscribeCallback<WorkerID, rpc::WorkerFailureData> &subscribe,
const SubscribeCallback<WorkerID, rpc::WorkerTableData> &subscribe,
const StatusCallback &done) {
RAY_CHECK(subscribe != nullptr);
subscribe_operation_ = [this, subscribe](const StatusCallback &done) {
auto on_subscribe = [subscribe](const std::string &id, const std::string &data) {
rpc::WorkerFailureData worker_failure_data;
rpc::WorkerTableData worker_failure_data;
worker_failure_data.ParseFromString(data);
subscribe(WorkerID::FromBinary(id), worker_failure_data);
};
return client_impl_->GetGcsPubSub().SubscribeAll(WORKER_FAILURE_CHANNEL, on_subscribe,
done);
return client_impl_->GetGcsPubSub().SubscribeAll(WORKER_CHANNEL, on_subscribe, done);
};
return subscribe_operation_(done);
}
Expand All @@ -1276,7 +1275,7 @@ void ServiceBasedWorkerInfoAccessor::AsyncResubscribe(bool is_pubsub_server_rest
}

Status ServiceBasedWorkerInfoAccessor::AsyncReportWorkerFailure(
const std::shared_ptr<rpc::WorkerFailureData> &data_ptr,
const std::shared_ptr<rpc::WorkerTableData> &data_ptr,
const StatusCallback &callback) {
rpc::Address worker_address = data_ptr->worker_address();
RAY_LOG(DEBUG) << "Reporting worker failure, " << worker_address.DebugString();
Expand All @@ -1294,22 +1293,48 @@ Status ServiceBasedWorkerInfoAccessor::AsyncReportWorkerFailure(
return Status::OK();
}

Status ServiceBasedWorkerInfoAccessor::AsyncRegisterWorker(
rpc::WorkerType worker_type, const WorkerID &worker_id,
const std::unordered_map<std::string, std::string> &worker_info,
const StatusCallback &callback) {
RAY_LOG(DEBUG) << "Registering the worker. worker id = " << worker_id;
rpc::RegisterWorkerRequest request;
request.set_worker_type(worker_type);
Status ServiceBasedWorkerInfoAccessor::AsyncGet(
const WorkerID &worker_id,
const OptionalItemCallback<rpc::WorkerTableData> &callback) {
RAY_LOG(DEBUG) << "Getting worker info, worker id = " << worker_id;
rpc::GetWorkerInfoRequest request;
request.set_worker_id(worker_id.Binary());
request.mutable_worker_info()->insert(worker_info.begin(), worker_info.end());
client_impl_->GetGcsRpcClient().RegisterWorker(
client_impl_->GetGcsRpcClient().GetWorkerInfo(
request,
[worker_id, callback](const Status &status, const rpc::RegisterWorkerReply &reply) {
[worker_id, callback](const Status &status, const rpc::GetWorkerInfoReply &reply) {
if (reply.has_worker_table_data()) {
callback(status, reply.worker_table_data());
} else {
callback(status, boost::none);
}
RAY_LOG(DEBUG) << "Finished getting worker info, worker id = " << worker_id;
});
return Status::OK();
}

Status ServiceBasedWorkerInfoAccessor::AsyncGetAll(
const MultiItemCallback<rpc::WorkerTableData> &callback) {
RAY_LOG(DEBUG) << "Getting all worker info.";
rpc::GetAllWorkerInfoRequest request;
client_impl_->GetGcsRpcClient().GetAllWorkerInfo(
request, [callback](const Status &status, const rpc::GetAllWorkerInfoReply &reply) {
auto result = VectorFromProtobuf(reply.worker_table_data());
callback(status, result);
RAY_LOG(DEBUG) << "Finished getting all worker info, status = " << status;
});
return Status::OK();
}

Status ServiceBasedWorkerInfoAccessor::AsyncAdd(
const std::shared_ptr<rpc::WorkerTableData> &data_ptr,
const StatusCallback &callback) {
rpc::AddWorkerInfoRequest request;
request.mutable_worker_data()->CopyFrom(*data_ptr);
client_impl_->GetGcsRpcClient().AddWorkerInfo(
request, [callback](const Status &status, const rpc::AddWorkerInfoReply &reply) {
if (callback) {
callback(status);
}
RAY_LOG(DEBUG) << "Finished registering worker. worker id = " << worker_id;
});
return Status::OK();
}
Expand Down
Loading

0 comments on commit 6f3d993

Please sign in to comment.