diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 8b12a04fd9795..453f66bb40091 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -72,7 +72,7 @@ cdef c_vector[CObjectID] ObjectIDsToVector(object_ids): ObjectID object_id c_vector[CObjectID] result for object_id in object_ids: - result.push_back(object_id.data) + result.push_back(object_id.native()) return result @@ -87,11 +87,11 @@ def compute_put_id(TaskID task_id, int64_t put_index): if put_index < 1 or put_index > kMaxTaskPuts: raise ValueError("The range of 'put_index' should be [1, %d]" % kMaxTaskPuts) - return ObjectID(ComputePutId(task_id.data, put_index).binary()) + return ObjectID(ComputePutId(task_id.native(), put_index).binary()) def compute_task_id(ObjectID object_id): - return TaskID(ComputeTaskId(object_id.data).binary()) + return TaskID(ComputeTaskId(object_id.native()).binary()) cdef c_bool is_simple_value(value, int *num_elements_contained): @@ -225,8 +225,8 @@ cdef class RayletClient: # parameter. # TODO(suquark): Should we allow unicode chars in "raylet_socket"? self.client.reset(new CRayletClient( - raylet_socket.encode("ascii"), client_id.data, is_worker, - driver_id.data, LANGUAGE_PYTHON)) + raylet_socket.encode("ascii"), client_id.native(), is_worker, + driver_id.native(), LANGUAGE_PYTHON)) def disconnect(self): check_status(self.client.get().Disconnect()) @@ -252,22 +252,23 @@ cdef class RayletClient: TaskID current_task_id=TaskID.nil()): cdef c_vector[CObjectID] fetch_ids = ObjectIDsToVector(object_ids) check_status(self.client.get().FetchOrReconstruct( - fetch_ids, fetch_only, current_task_id.data)) + fetch_ids, fetch_only, current_task_id.native())) def notify_unblocked(self, TaskID current_task_id): - check_status(self.client.get().NotifyUnblocked(current_task_id.data)) + check_status(self.client.get().NotifyUnblocked(current_task_id.native())) def wait(self, object_ids, int num_returns, int64_t timeout_milliseconds, c_bool wait_local, TaskID current_task_id): cdef: WaitResultPair result c_vector[CObjectID] wait_ids + CTaskID c_task_id = current_task_id.native() wait_ids = ObjectIDsToVector(object_ids) with nogil: check_status(self.client.get().Wait(wait_ids, num_returns, timeout_milliseconds, wait_local, - current_task_id.data, &result)) + c_task_id, &result)) return (VectorToObjectIDs(result.first), VectorToObjectIDs(result.second)) @@ -291,9 +292,9 @@ cdef class RayletClient: postincrement(iterator) return resources_dict - def push_error(self, DriverID job_id, error_type, error_message, + def push_error(self, DriverID driver_id, error_type, error_message, double timestamp): - check_status(self.client.get().PushError(job_id.data, + check_status(self.client.get().PushError(driver_id.native(), error_type.encode("ascii"), error_message.encode("ascii"), timestamp)) @@ -354,7 +355,7 @@ cdef class RayletClient: def prepare_actor_checkpoint(self, ActorID actor_id): cdef CActorCheckpointID checkpoint_id - cdef CActorID c_actor_id = actor_id.data + cdef CActorID c_actor_id = actor_id.native() # PrepareActorCheckpoint will wait for raylet's reply, release # the GIL so other Python threads can run. with nogil: @@ -365,7 +366,7 @@ cdef class RayletClient: def notify_actor_resumed_from_checkpoint(self, ActorID actor_id, ActorCheckpointID checkpoint_id): check_status(self.client.get().NotifyActorResumedFromCheckpoint( - actor_id.data, checkpoint_id.data)) + actor_id.native(), checkpoint_id.native())) @property def language(self): diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index 1a4ffb2502357..a496c5b837835 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -62,7 +62,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: int num_returns, int64_t timeout_milliseconds, c_bool wait_local, const CTaskID ¤t_task_id, WaitResultPair *result) - CRayStatus PushError(const CDriverID &job_id, const c_string &type, + CRayStatus PushError(const CDriverID &driver_id, const c_string &type, const c_string &error_message, double timestamp) CRayStatus PushProfileEvents( const GCSProfileTableDataT &profile_events) diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index a7cfc684b9d03..872b93d222693 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -54,7 +54,7 @@ cdef class Task: for arg in arguments: if isinstance(arg, ObjectID): references = c_vector[CObjectID]() - references.push_back((arg).data) + references.push_back((arg).native()) task_args.push_back( static_pointer_cast[CTaskArgument, CTaskArgumentByReference]( @@ -71,23 +71,21 @@ cdef class Task: for new_actor_handle in new_actor_handles: task_new_actor_handles.push_back( - (new_actor_handle).data) + (new_actor_handle).native()) self.task_spec.reset(new CTaskSpecification( - CUniqueID(driver_id.data), parent_task_id.data, parent_counter, - actor_creation_id.data, actor_creation_dummy_object_id.data, - max_actor_reconstructions, CUniqueID(actor_id.data), - CUniqueID(actor_handle_id.data), actor_counter, - task_new_actor_handles, task_args, num_returns, - required_resources, required_placement_resources, - LANGUAGE_PYTHON, c_function_descriptor)) + driver_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(), + actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(), + actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns, + required_resources, required_placement_resources, LANGUAGE_PYTHON, + c_function_descriptor)) # Set the task's execution dependencies. self.execution_dependencies.reset(new c_vector[CObjectID]()) if execution_arguments is not None: for execution_arg in execution_arguments: self.execution_dependencies.get().push_back( - (execution_arg).data) + (execution_arg).native()) @staticmethod cdef make(unique_ptr[CTaskSpecification]& task_spec): diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index fc36f97766c19..cadbdfea2827d 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -5,13 +5,14 @@ from libc.stdint cimport uint8_t cdef extern from "ray/id.h" namespace "ray" nogil: cdef cppclass CUniqueID "ray::UniqueID": CUniqueID() + CUniqueID(const c_string &binary) CUniqueID(const CUniqueID &from_id) @staticmethod CUniqueID from_random() @staticmethod - CUniqueID from_binary(const c_string & binary) + CUniqueID from_binary(const c_string &binary) @staticmethod const CUniqueID nil() @@ -26,14 +27,73 @@ cdef extern from "ray/id.h" namespace "ray" nogil: c_string binary() const c_string hex() const -ctypedef CUniqueID CActorCheckpointID -ctypedef CUniqueID CActorClassID -ctypedef CUniqueID CActorHandleID -ctypedef CUniqueID CActorID -ctypedef CUniqueID CClientID -ctypedef CUniqueID CConfigID -ctypedef CUniqueID CDriverID -ctypedef CUniqueID CFunctionID -ctypedef CUniqueID CObjectID -ctypedef CUniqueID CTaskID -ctypedef CUniqueID CWorkerID + cdef cppclass CActorCheckpointID "ray::ActorCheckpointID"(CUniqueID): + + @staticmethod + CActorCheckpointID from_binary(const c_string &binary) + + + cdef cppclass CActorClassID "ray::ActorClassID"(CUniqueID): + + @staticmethod + CActorClassID from_binary(const c_string &binary) + + + cdef cppclass CActorID "ray::ActorID"(CUniqueID): + + @staticmethod + CActorID from_binary(const c_string &binary) + + + cdef cppclass CActorHandleID "ray::ActorHandleID"(CUniqueID): + + @staticmethod + CActorHandleID from_binary(const c_string &binary) + + + cdef cppclass CClientID "ray::ClientID"(CUniqueID): + + @staticmethod + CClientID from_binary(const c_string &binary) + + + cdef cppclass CConfigID "ray::ConfigID"(CUniqueID): + + @staticmethod + CConfigID from_binary(const c_string &binary) + + + cdef cppclass CFunctionID "ray::FunctionID"(CUniqueID): + + @staticmethod + CFunctionID from_binary(const c_string &binary) + + + cdef cppclass CDriverID "ray::DriverID"(CUniqueID): + + @staticmethod + CDriverID from_binary(const c_string &binary) + + + cdef cppclass CJobID "ray::JobID"(CUniqueID): + + @staticmethod + CJobID from_binary(const c_string &binary) + + + cdef cppclass CTaskID "ray::TaskID"(CUniqueID): + + @staticmethod + CTaskID from_binary(const c_string &binary) + + + cdef cppclass CObjectID" ray::ObjectID"(CUniqueID): + + @staticmethod + CObjectID from_binary(const c_string &binary) + + + cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): + + @staticmethod + CWorkerID from_binary(const c_string &binary) diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 670579737d7cc..0086f76b51b01 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -19,6 +19,7 @@ from ray.includes.unique_ids cimport ( CConfigID, CDriverID, CFunctionID, + CJobID, CObjectID, CTaskID, CUniqueID, @@ -45,11 +46,8 @@ cdef class UniqueID: cdef CUniqueID data def __init__(self, id): - if not id: - self.data = CUniqueID() - else: - check_id(id) - self.data = CUniqueID.from_binary(id) + check_id(id) + self.data = CUniqueID.from_binary(id) @classmethod def from_binary(cls, id_bytes): @@ -59,7 +57,7 @@ cdef class UniqueID: @classmethod def nil(cls): - return cls(b"") + return cls(CUniqueID.nil().binary()) def __hash__(self): return self.data.hash() @@ -106,40 +104,93 @@ cdef class UniqueID: cdef class ObjectID(UniqueID): - pass + + def __init__(self, id): + check_id(id) + self.data = CObjectID.from_binary(id) + + cdef CObjectID native(self): + return self.data cdef class TaskID(UniqueID): - pass + + def __init__(self, id): + check_id(id) + self.data = CTaskID.from_binary(id) + + cdef CTaskID native(self): + return self.data cdef class ClientID(UniqueID): - pass + + def __init__(self, id): + check_id(id) + self.data = CClientID.from_binary(id) + + cdef CClientID native(self): + return self.data cdef class DriverID(UniqueID): - pass + + def __init__(self, id): + check_id(id) + self.data = CDriverID.from_binary(id) + + cdef CDriverID native(self): + return self.data cdef class ActorID(UniqueID): - pass + + def __init__(self, id): + check_id(id) + self.data = CActorID.from_binary(id) + + cdef CActorID native(self): + return self.data cdef class ActorHandleID(UniqueID): - pass + + def __init__(self, id): + check_id(id) + self.data = CActorHandleID.from_binary(id) + + cdef CActorHandleID native(self): + return self.data cdef class ActorCheckpointID(UniqueID): - pass + + def __init__(self, id): + check_id(id) + self.data = CActorCheckpointID.from_binary(id) + + cdef CActorCheckpointID native(self): + return self.data cdef class FunctionID(UniqueID): - pass + + def __init__(self, id): + check_id(id) + self.data = CFunctionID.from_binary(id) + + cdef CFunctionID native(self): + return self.data cdef class ActorClassID(UniqueID): - pass + def __init__(self, id): + check_id(id) + self.data = CActorClassID.from_binary(id) + + cdef CActorClassID native(self): + return self.data _ID_TYPES = [ ActorCheckpointID, diff --git a/src/ray/common/common_protocol.cc b/src/ray/common/common_protocol.cc index f5ed40af570cb..adce684fc299f 100644 --- a/src/ray/common/common_protocol.cc +++ b/src/ray/common/common_protocol.cc @@ -2,74 +2,6 @@ #include "ray/util/logging.h" -flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - ray::ObjectID object_id) { - return fbb.CreateString(reinterpret_cast(object_id.data()), - sizeof(ray::ObjectID)); -} - -ray::ObjectID from_flatbuf(const flatbuffers::String &string) { - ray::ObjectID object_id; - RAY_CHECK(string.size() == sizeof(ray::ObjectID)); - memcpy(object_id.mutable_data(), string.data(), sizeof(ray::ObjectID)); - return object_id; -} - -const std::vector from_flatbuf( - const flatbuffers::Vector> &vector) { - std::vector object_ids; - for (int64_t i = 0; i < vector.Length(); i++) { - object_ids.push_back(from_flatbuf(*vector.Get(i))); - } - return object_ids; -} - -const std::vector object_ids_from_flatbuf( - const flatbuffers::String &string) { - const auto &object_ids = string_from_flatbuf(string); - std::vector ret; - RAY_CHECK(object_ids.size() % kUniqueIDSize == 0); - auto count = object_ids.size() / kUniqueIDSize; - - for (size_t i = 0; i < count; ++i) { - auto pos = static_cast(kUniqueIDSize * i); - const auto &id = object_ids.substr(pos, kUniqueIDSize); - ret.push_back(ray::ObjectID::from_binary(id)); - } - - return ret; -} - -flatbuffers::Offset object_ids_to_flatbuf( - flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids) { - std::string result; - for (const auto &id : object_ids) { - result += id.binary(); - } - - return fbb.CreateString(result); -} - -flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ray::ObjectID object_ids[], - int64_t num_objects) { - std::vector> results; - for (int64_t i = 0; i < num_objects; i++) { - results.push_back(to_flatbuf(fbb, object_ids[i])); - } - return fbb.CreateVector(results); -} - -flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - const std::vector &object_ids) { - std::vector> results; - for (auto object_id : object_ids) { - results.push_back(to_flatbuf(fbb, object_id)); - } - return fbb.CreateVector(results); -} - std::string string_from_flatbuf(const flatbuffers::String &string) { return std::string(string.data(), string.size()); } diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index bea4a5b925424..bc3d9b646a4b6 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -6,63 +6,68 @@ #include #include "ray/id.h" +#include "ray/util/logging.h" -/// Convert an object ID to a flatbuffer string. +/// Convert an unique ID to a flatbuffer string. /// /// @param fbb Reference to the flatbuffer builder. -/// @param object_id The object ID to be converted. -/// @return The flatbuffer string contining the object ID. +/// @param id The ID to be converted. +/// @return The flatbuffer string containing the ID. +template flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - ray::ObjectID object_id); + ID id); -/// Convert a flatbuffer string to an object ID. +/// Convert a flatbuffer string to an unique ID. /// /// @param string The flatbuffer string. -/// @return The object ID. -ray::ObjectID from_flatbuf(const flatbuffers::String &string); +/// @return The ID. +template +ID from_flatbuf(const flatbuffers::String &string); -/// Convert a flatbuffer vector of strings to a vector of object IDs. +/// Convert a flatbuffer vector of strings to a vector of unique IDs. /// /// @param vector The flatbuffer vector. -/// @return The vector of object IDs. -const std::vector from_flatbuf( +/// @return The vector of IDs. +template +const std::vector from_flatbuf( const flatbuffers::Vector> &vector); /// Convert a flatbuffer of string that concatenated -/// object IDs to a vector of object IDs. +/// unique IDs to a vector of unique IDs. /// /// @param vector The flatbuffer vector. -/// @return The vector of object IDs. -const std::vector object_ids_from_flatbuf( - const flatbuffers::String &string); +/// @return The vector of IDs. +template +const std::vector ids_from_flatbuf(const flatbuffers::String &string); -/// Convert a vector of object IDs to a flatbuffer string. +/// Convert a vector of unique IDs to a flatbuffer string. /// The IDs are concatenated to a string with binary. /// /// @param fbb Reference to the flatbuffer builder. -/// @param object_ids The vector of object IDs. +/// @param ids The vector of IDs. /// @return Flatbuffer string of concatenated IDs. -flatbuffers::Offset object_ids_to_flatbuf( - flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids); +template +flatbuffers::Offset ids_to_flatbuf( + flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids); -/// Convert an array of object IDs to a flatbuffer vector of strings. +/// Convert an array of unique IDs to a flatbuffer vector of strings. /// /// @param fbb Reference to the flatbuffer builder. -/// @param object_ids Array of object IDs. -/// @param num_objects Number of elements in the array. +/// @param ids Array of unique IDs. +/// @param num_ids Number of elements in the array. /// @return Flatbuffer vector of strings. +template flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ray::ObjectID object_ids[], - int64_t num_objects); +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID ids[], int64_t num_ids); -/// Convert a vector of object IDs to a flatbuffer vector of strings. +/// Convert a vector of unique IDs to a flatbuffer vector of strings. /// /// @param fbb Reference to the flatbuffer builder. -/// @param object_ids Vector of object IDs. +/// @param ids Vector of IDs. /// @return Flatbuffer vector of strings. +template flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - const std::vector &object_ids); +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids); /// Convert a flatbuffer string to a std::string. /// @@ -95,4 +100,76 @@ std::vector string_vec_from_flatbuf( flatbuffers::Offset>> string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &string_vector); + +template +flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, + ID id) { + return fbb.CreateString(reinterpret_cast(id.data()), sizeof(ID)); +} + +template +ID from_flatbuf(const flatbuffers::String &string) { + ID id; + RAY_CHECK(string.size() == sizeof(ID)); + memcpy(id.mutable_data(), string.data(), sizeof(ID)); + return id; +} + +template +const std::vector from_flatbuf( + const flatbuffers::Vector> &vector) { + std::vector ids; + for (int64_t i = 0; i < vector.Length(); i++) { + ids.push_back(from_flatbuf(*vector.Get(i))); + } + return ids; +} + +template +const std::vector ids_from_flatbuf(const flatbuffers::String &string) { + const auto &ids = string_from_flatbuf(string); + std::vector ret; + RAY_CHECK(ids.size() % kUniqueIDSize == 0); + auto count = ids.size() / kUniqueIDSize; + + for (size_t i = 0; i < count; ++i) { + auto pos = static_cast(kUniqueIDSize * i); + const auto &id = ids.substr(pos, kUniqueIDSize); + ret.push_back(ID::from_binary(id)); + } + + return ret; +} + +template +flatbuffers::Offset ids_to_flatbuf( + flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids) { + std::string result; + for (const auto &id : ids) { + result += id.binary(); + } + + return fbb.CreateString(result); +} + +template +flatbuffers::Offset>> +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID ids[], int64_t num_ids) { + std::vector> results; + for (int64_t i = 0; i < num_ids; i++) { + results.push_back(to_flatbuf(fbb, ids[i])); + } + return fbb.CreateVector(results); +} + +template +flatbuffers::Offset>> +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids) { + std::vector> results; + for (auto id : ids) { + results.push_back(to_flatbuf(fbb, id)); + } + return fbb.CreateVector(results); +} + #endif diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 6bf2a53156bef..b7aab1582ef02 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -814,7 +814,7 @@ void TestClientTableConnect(const JobID &job_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, true); test->Stop(); }); @@ -839,14 +839,14 @@ void TestClientTableDisconnect(const JobID &job_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, /*is_insertion=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, /*is_insertion=*/false); test->Stop(); }); @@ -870,11 +870,11 @@ void TestClientTableImmediateDisconnect(const JobID &job_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, true); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, false); test->Stop(); }); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 87a72258ba2af..8e60c3a0d96fa 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -91,7 +91,7 @@ Status Log::Lookup(const JobID &job_id, const ID &id, const Callback & std::vector results; if (!data.empty()) { auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); + RAY_CHECK(from_flatbuf(*root->id()) == id); for (size_t i = 0; i < root->entries()->size(); i++) { DataT result; auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); @@ -128,7 +128,7 @@ Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, auto root = flatbuffers::GetRoot(data.data()); ID id; if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); + id = from_flatbuf(*root->id()); } std::vector results; for (size_t i = 0; i < root->entries()->size(); i++) { @@ -274,18 +274,18 @@ std::string Table::DebugString() const { return result.str(); } -Status ErrorTable::PushErrorToDriver(const JobID &job_id, const std::string &type, +Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { auto data = std::make_shared(); - data->job_id = job_id.binary(); + data->job_id = driver_id.binary(); data->type = type; data->error_message = error_message; data->timestamp = timestamp; - return Append(job_id, job_id, data, /*done_callback=*/nullptr); + return Append(JobID(driver_id), driver_id, data, /*done_callback=*/nullptr); } std::string ErrorTable::DebugString() const { - return Log::DebugString(); + return Log::DebugString(); } Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { @@ -302,11 +302,11 @@ std::string ProfileTable::DebugString() const { return Log::DebugString(); } -Status DriverTable::AppendDriverData(const JobID &driver_id, bool is_dead) { +Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { auto data = std::make_shared(); data->driver_id = driver_id.binary(); data->is_dead = is_dead; - return Append(driver_id, driver_id, data, /*done_callback=*/nullptr); + return Append(JobID(driver_id), driver_id, data, /*done_callback=*/nullptr); } void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) { @@ -492,7 +492,7 @@ Status ClientTable::Lookup(const Callback &lookup) { std::string ClientTable::DebugString() const { std::stringstream result; - result << Log::DebugString(); + result << Log::DebugString(); result << ", cache size: " << client_cache_.size() << ", num removed: " << removed_clients_.size(); return result.str(); @@ -500,7 +500,7 @@ std::string ClientTable::DebugString() const { Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, const ActorID &actor_id, - const UniqueID &checkpoint_id) { + const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, job_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, const ActorCheckpointIdDataT &data) { @@ -512,7 +512,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, while (copy->timestamps.size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. const auto &checkpoint_id = - UniqueID::from_binary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); + ActorCheckpointID::from_binary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " << actor_id; copy->timestamps.erase(copy->timestamps.begin()); @@ -542,9 +542,9 @@ template class Log; template class Table; template class Table; template class Table; -template class Log; -template class Log; -template class Log; +template class Log; +template class Log; +template class Log; template class Log; template class Table; template class Table; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 71e1c39d6da7d..2aabf2ae3e017 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -382,7 +382,7 @@ class HeartbeatBatchTable : public Table { virtual ~HeartbeatBatchTable() {} }; -class DriverTable : public Log { +class DriverTable : public Log { public: DriverTable(const std::vector> &contexts, AsyncGcsClient *client) @@ -398,7 +398,7 @@ class DriverTable : public Log { /// \param driver_id The driver id. /// \param is_dead Whether the driver is dead. /// \return The return status. - Status AppendDriverData(const JobID &driver_id, bool is_dead); + Status AppendDriverData(const DriverID &driver_id, bool is_dead); }; class FunctionTable : public Table { @@ -488,7 +488,7 @@ class ActorCheckpointIdTable : public Table { /// \param checkpoint_id ID of the checkpoint. /// \return Status. Status AddCheckpointId(const JobID &job_id, const ActorID &actor_id, - const UniqueID &checkpoint_id); + const ActorCheckpointID &checkpoint_id); }; namespace raylet { @@ -511,7 +511,7 @@ class TaskTable : public Table { } // namespace raylet -class ErrorTable : private Log { +class ErrorTable : private Log { public: ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) @@ -532,7 +532,7 @@ class ErrorTable : private Log { /// \param error_message The error message to push. /// \param timestamp The timestamp of the error. /// \return Status. - Status PushErrorToDriver(const JobID &job_id, const std::string &type, + Status PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp); /// Returns debug string for class. @@ -574,7 +574,7 @@ using ConfigTable = Table; /// it should append an entry to the log indicating that it is dead. A client /// that is marked as dead should never again be marked as alive; if it needs /// to reconnect, it must connect with a different ClientID. -class ClientTable : private Log { +class ClientTable : private Log { public: using ClientTableCallback = std::function; @@ -678,7 +678,7 @@ class ClientTable : private Log { /// The key at which the log of client information is stored. This key must /// be kept the same across all instances of the ClientTable, so that all /// clients append and read from the same key. - UniqueID client_log_key_; + ClientID client_log_key_; /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. diff --git a/src/ray/id.cc b/src/ray/id.cc index 70454bbdfb0d6..a9d9c5a7e7652 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -165,7 +165,7 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id) { const ObjectID ComputeObjectId(const TaskID &task_id, int64_t object_index) { RAY_CHECK(object_index <= kMaxTaskReturns && object_index >= -kMaxTaskPuts); - ObjectID return_id = task_id; + ObjectID return_id = ObjectID(task_id); int64_t *first_bytes = reinterpret_cast(&return_id); // Zero out the lowest kObjectIdIndexSize bits of the first byte of the // object ID. @@ -176,7 +176,9 @@ const ObjectID ComputeObjectId(const TaskID &task_id, int64_t object_index) { return return_id; } -const TaskID FinishTaskId(const TaskID &task_id) { return ComputeObjectId(task_id, 0); } +const TaskID FinishTaskId(const TaskID &task_id) { + return TaskID(ComputeObjectId(task_id, 0)); +} const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index) { RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns); @@ -190,7 +192,7 @@ const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index) { } const TaskID ComputeTaskId(const ObjectID &object_id) { - TaskID task_id = object_id; + TaskID task_id = TaskID(object_id); int64_t *first_bytes = reinterpret_cast(&task_id); // Zero out the lowest kObjectIdIndexSize bits of the first byte of the // object ID. diff --git a/src/ray/id.h b/src/ray/id.h index 562365951fc28..35c67b220faf1 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -30,7 +30,7 @@ class RAY_EXPORT UniqueID { std::string hex() const; plasma::UniqueID to_plasma_id() const; - private: + protected: uint8_t id_[kUniqueIDSize]; }; @@ -38,18 +38,24 @@ static_assert(std::is_standard_layout::value, "UniqueID must be standa std::ostream &operator<<(std::ostream &os, const UniqueID &id); -typedef UniqueID TaskID; -typedef UniqueID JobID; -typedef UniqueID ObjectID; -typedef UniqueID FunctionID; -typedef UniqueID ActorClassID; -typedef UniqueID ActorID; -typedef UniqueID ActorHandleID; -typedef UniqueID ActorCheckpointID; -typedef UniqueID WorkerID; -typedef UniqueID DriverID; -typedef UniqueID ConfigID; -typedef UniqueID ClientID; +#define DEFINE_UNIQUE_ID(type) \ + class RAY_EXPORT type : public UniqueID { \ + public: \ + explicit type(const UniqueID &from) { \ + std::memcpy(&id_, from.data(), kUniqueIDSize); \ + } \ + type() : UniqueID() {} \ + static type from_random() { return type(UniqueID::from_random()); } \ + static type from_binary(const std::string &binary) { return type(binary); } \ + static type nil() { return type(UniqueID::nil()); } \ + \ + private: \ + type(const std::string &binary) { std::memcpy(id_, binary.data(), kUniqueIDSize); } \ + }; + +#include "id_def.h" + +#undef DEFINE_UNIQUE_ID // TODO(swang): ObjectID and TaskID should derive from UniqueID. Then, we // can make these methods of the derived classes. @@ -101,14 +107,20 @@ int64_t ComputeObjectIndex(const ObjectID &object_id); } // namespace ray namespace std { -template <> -struct hash<::ray::UniqueID> { - size_t operator()(const ::ray::UniqueID &id) const { return id.hash(); } -}; -template <> -struct hash { - size_t operator()(const ::ray::UniqueID &id) const { return id.hash(); } -}; -} +#define DEFINE_UNIQUE_ID(type) \ + template <> \ + struct hash<::ray::type> { \ + size_t operator()(const ::ray::type &id) const { return id.hash(); } \ + }; \ + template <> \ + struct hash { \ + size_t operator()(const ::ray::type &id) const { return id.hash(); } \ + }; + +DEFINE_UNIQUE_ID(UniqueID); +#include "id_def.h" + +#undef DEFINE_UNIQUE_ID +} // namespace std #endif // RAY_ID_H_ diff --git a/src/ray/id_def.h b/src/ray/id_def.h new file mode 100644 index 0000000000000..8e8b2b3fb717d --- /dev/null +++ b/src/ray/id_def.h @@ -0,0 +1,18 @@ +// This header file is used to avoid code duplication. +// It can be included multiple times in id.h, and each inclusion +// could use a different definition of the DEFINE_UNIQUE_ID macro. +// Macro definition format: DEFINE_UNIQUE_ID(id_type). +// NOTE: This file should NOT be included in any file other than id.h. + +DEFINE_UNIQUE_ID(TaskID); +DEFINE_UNIQUE_ID(JobID); +DEFINE_UNIQUE_ID(ObjectID); +DEFINE_UNIQUE_ID(FunctionID); +DEFINE_UNIQUE_ID(ActorClassID); +DEFINE_UNIQUE_ID(ActorID); +DEFINE_UNIQUE_ID(ActorHandleID); +DEFINE_UNIQUE_ID(ActorCheckpointID); +DEFINE_UNIQUE_ID(WorkerID); +DEFINE_UNIQUE_ID(DriverID); +DEFINE_UNIQUE_ID(ConfigID); +DEFINE_UNIQUE_ID(ClientID); diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 51cb2600beb3f..d9f7b87a700e9 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -78,7 +78,7 @@ void ObjectDirectory::RegisterBackend() { } }; RAY_CHECK_OK(gcs_client_->object_table().Subscribe( - UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), + JobID::nil(), gcs_client_->client_table().GetLocalClientId(), object_notification_callback, nullptr)); } diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 5459985e5b61f..7c949be311a6b 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -767,7 +767,7 @@ void ObjectManager::ConnectClient(std::shared_ptr &conn, // TODO: trash connection on failure. auto info = flatbuffers::GetRoot(message); - ClientID client_id = ObjectID::from_binary(info->client_id()->str()); + ClientID client_id = ClientID::from_binary(info->client_id()->str()); bool is_transfer = info->is_transfer(); conn->SetClientID(client_id); if (is_transfer) { @@ -885,7 +885,7 @@ void ObjectManager::ReceiveFreeRequest(std::shared_ptr &con const uint8_t *message) { auto free_request = flatbuffers::GetRoot(message); - std::vector object_ids = from_flatbuf(*free_request->object_ids()); + std::vector object_ids = from_flatbuf(*free_request->object_ids()); // This RPC should come from another Object Manager. // Keep this request local. bool local_only = true; diff --git a/src/ray/object_manager/object_store_notification_manager.cc b/src/ray/object_manager/object_store_notification_manager.cc index aa19787f3c37f..746f4d622d5af 100644 --- a/src/ray/object_manager/object_store_notification_manager.cc +++ b/src/ray/object_manager/object_store_notification_manager.cc @@ -58,7 +58,7 @@ void ObjectStoreNotificationManager::ProcessStoreNotification( const auto &object_info = flatbuffers::GetRoot(notification_.data()); - const auto &object_id = from_flatbuf(*object_info->object_id()); + const auto &object_id = from_flatbuf(*object_info->object_id()); if (object_info->is_deletion()) { ProcessStoreRemove(object_id); } else { diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 710928cdbd880..20bb1c735c1c5 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -196,7 +196,7 @@ table WaitReply { // This struct is the same as ErrorTableData. table PushErrorRequest { // The ID of the job that the error is for. - job_id: string; + driver_id: string; // The type of the error. type: string; // The error message. diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 68004a37bf218..c55b2608b2fda 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -6,31 +6,30 @@ #include "ray/raylet/raylet_client.h" #include "ray/util/logging.h" -#ifdef __cplusplus -extern "C" { -#endif - +template class UniqueIdFromJByteArray { - private: - JNIEnv *_env; - jbyteArray _bytes; - public: - UniqueID *PID; + const ID &GetId() const { return *id_pointer_; } - UniqueIdFromJByteArray(JNIEnv *env, jbyteArray wid) { - _env = env; - _bytes = wid; - - jbyte *b = reinterpret_cast(_env->GetByteArrayElements(_bytes, nullptr)); - PID = reinterpret_cast(b); + UniqueIdFromJByteArray(JNIEnv *env, jbyteArray bytes) : env_(env), bytes_(bytes) { + jbyte *b = reinterpret_cast(env_->GetByteArrayElements(bytes_, nullptr)); + id_pointer_ = reinterpret_cast(b); } ~UniqueIdFromJByteArray() { - _env->ReleaseByteArrayElements(_bytes, reinterpret_cast(PID), 0); + env_->ReleaseByteArrayElements(bytes_, reinterpret_cast(id_pointer_), 0); } + + private: + JNIEnv *env_; + jbyteArray bytes_; + ID *id_pointer_; }; +#ifdef __cplusplus +extern "C" { +#endif + inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) { if (!status.ok()) { jclass exception_class = env->FindClass("org/ray/api/exception/RayException"); @@ -49,11 +48,11 @@ inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) { JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker, jbyteArray driverId) { - UniqueIdFromJByteArray worker_id(env, workerId); - UniqueIdFromJByteArray driver_id(env, driverId); + UniqueIdFromJByteArray worker_id(env, workerId); + UniqueIdFromJByteArray driver_id(env, driverId); const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); - auto raylet_client = new RayletClient(nativeString, *worker_id.PID, isWorker, - *driver_id.PID, Language::JAVA); + auto raylet_client = new RayletClient(nativeString, worker_id.GetId(), isWorker, + driver_id.GetId(), Language::JAVA); env->ReleaseStringUTFChars(sockName, nativeString); return reinterpret_cast(raylet_client); } @@ -70,8 +69,8 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit std::vector execution_dependencies; if (cursorId != nullptr) { - UniqueIdFromJByteArray cursor_id(env, cursorId); - execution_dependencies.push_back(*cursor_id.PID); + UniqueIdFromJByteArray cursor_id(env, cursorId); + execution_dependencies.push_back(cursor_id.GetId()); } auto data = reinterpret_cast(env->GetDirectBufferAddress(taskBuff)) + pos; @@ -143,14 +142,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( for (int i = 0; i < len; i++) { jbyteArray object_id_bytes = static_cast(env->GetObjectArrayElement(objectIds, i)); - UniqueIdFromJByteArray object_id(env, object_id_bytes); - object_ids.push_back(*object_id.PID); + UniqueIdFromJByteArray object_id(env, object_id_bytes); + object_ids.push_back(object_id.GetId()); env->DeleteLocalRef(object_id_bytes); } - UniqueIdFromJByteArray current_task_id(env, currentTaskId); + UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto raylet_client = reinterpret_cast(client); auto status = - raylet_client->FetchOrReconstruct(object_ids, fetchOnly, *current_task_id.PID); + raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id.GetId()); ThrowRayExceptionIfNotOK(env, status); } @@ -161,9 +160,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked( JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) { - UniqueIdFromJByteArray current_task_id(env, currentTaskId); + UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto raylet_client = reinterpret_cast(client); - auto status = raylet_client->NotifyUnblocked(*current_task_id.PID); + auto status = raylet_client->NotifyUnblocked(current_task_id.GetId()); ThrowRayExceptionIfNotOK(env, status); } @@ -181,19 +180,19 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( for (int i = 0; i < len; i++) { jbyteArray object_id_bytes = static_cast(env->GetObjectArrayElement(objectIds, i)); - UniqueIdFromJByteArray object_id(env, object_id_bytes); - object_ids.push_back(*object_id.PID); + UniqueIdFromJByteArray object_id(env, object_id_bytes); + object_ids.push_back(object_id.GetId()); env->DeleteLocalRef(object_id_bytes); } - UniqueIdFromJByteArray current_task_id(env, currentTaskId); + UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto raylet_client = reinterpret_cast(client); // Invoke wait. WaitResultPair result; - auto status = - raylet_client->Wait(object_ids, numReturns, timeoutMillis, - static_cast(isWaitLocal), *current_task_id.PID, &result); + auto status = raylet_client->Wait(object_ids, numReturns, timeoutMillis, + static_cast(isWaitLocal), + current_task_id.GetId(), &result); if (ThrowRayExceptionIfNotOK(env, status)) { return nullptr; } @@ -231,15 +230,12 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId( JNIEnv *env, jclass, jbyteArray driverId, jbyteArray parentTaskId, jint parent_task_counter) { - UniqueIdFromJByteArray object_id1(env, driverId); - ray::DriverID driver_id = *object_id1.PID; + UniqueIdFromJByteArray driver_id(env, driverId); + UniqueIdFromJByteArray parent_task_id(env, parentTaskId); - UniqueIdFromJByteArray object_id2(env, parentTaskId); - ray::TaskID parent_task_id = *object_id2.PID; - - ray::TaskID task_id = - ray::GenerateTaskId(driver_id, parent_task_id, parent_task_counter); - jbyteArray result = env->NewByteArray(sizeof(ray::TaskID)); + TaskID task_id = + ray::GenerateTaskId(driver_id.GetId(), parent_task_id.GetId(), parent_task_counter); + jbyteArray result = env->NewByteArray(sizeof(TaskID)); if (nullptr == result) { return nullptr; } @@ -261,8 +257,8 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( for (int i = 0; i < len; i++) { jbyteArray object_id_bytes = static_cast(env->GetObjectArrayElement(objectIds, i)); - UniqueIdFromJByteArray object_id(env, object_id_bytes); - object_ids.push_back(*object_id.PID); + UniqueIdFromJByteArray object_id(env, object_id_bytes); + object_ids.push_back(object_id.GetId()); env->DeleteLocalRef(object_id_bytes); } auto raylet_client = reinterpret_cast(client); @@ -280,9 +276,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env jlong client, jbyteArray actorId) { auto raylet_client = reinterpret_cast(client); - UniqueIdFromJByteArray actor_id(env, actorId); + UniqueIdFromJByteArray actor_id(env, actorId); ActorCheckpointID checkpoint_id; - auto status = raylet_client->PrepareActorCheckpoint(*actor_id.PID, checkpoint_id); + auto status = raylet_client->PrepareActorCheckpoint(actor_id.GetId(), checkpoint_id); if (ThrowRayExceptionIfNotOK(env, status)) { return nullptr; } @@ -301,10 +297,10 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) { auto raylet_client = reinterpret_cast(client); - UniqueIdFromJByteArray actor_id(env, actorId); - UniqueIdFromJByteArray checkpoint_id(env, checkpointId); - auto status = - raylet_client->NotifyActorResumedFromCheckpoint(*actor_id.PID, *checkpoint_id.PID); + UniqueIdFromJByteArray actor_id(env, actorId); + UniqueIdFromJByteArray checkpoint_id(env, checkpointId); + auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id.GetId(), + checkpoint_id.GetId()); ThrowRayExceptionIfNotOK(env, status); } diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 93e56a93a81b6..949dc9eca1c21 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -358,8 +358,9 @@ void LineageCache::FlushTask(const TaskID &task_id) { auto task_data = std::make_shared(); auto root = flatbuffers::GetRoot(fbb.GetBufferPointer()); root->UnPackTo(task_data.get()); - RAY_CHECK_OK(task_storage_.Add(task->TaskData().GetTaskSpecification().DriverId(), - task_id, task_data, task_callback)); + RAY_CHECK_OK( + task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().DriverId()), + task_id, task_data, task_callback)); // We successfully wrote the task, so mark it as committing. // TODO(swang): Use a batched interface and write with all object entries. diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 973483759e4bb..1ed0dcc84f393 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -113,9 +113,9 @@ static inline Task ExampleTask(const std::vector &arguments, task_arguments.emplace_back(std::make_shared(references)); } std::vector function_descriptor(3); - auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0, - task_arguments, num_returns, required_resources, - Language::PYTHON, function_descriptor); + auto spec = TaskSpecification(DriverID::nil(), TaskID::from_random(), 0, task_arguments, + num_returns, required_resources, Language::PYTHON, + function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); execution_spec.IncrementNumForwards(); Task task = Task(execution_spec, spec); diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 30f05de226c42..d18edbad8238e 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -35,7 +35,7 @@ void Monitor::Start() { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( - UniqueID::nil(), UniqueID::nil(), heartbeat_callback, nullptr, nullptr)); + JobID::nil(), ClientID::nil(), heartbeat_callback, nullptr, nullptr)); Tick(); } @@ -69,7 +69,7 @@ void Monitor::Tick() { << " has missed too many heartbeats from it."; // We use the nil JobID to broadcast the message to all drivers. RAY_CHECK_OK(gcs_client_.error_table().PushErrorToDriver( - JobID::nil(), type, error_message.str(), current_time_ms())); + DriverID::nil(), type, error_message.str(), current_time_ms())); } }; RAY_CHECK_OK(gcs_client_.client_table().Lookup(lookup_callback)); @@ -88,7 +88,7 @@ void Monitor::Tick() { batch->batch.push_back(std::unique_ptr( new HeartbeatTableDataT(heartbeat.second))); } - RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(UniqueID::nil(), UniqueID::nil(), + RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(JobID::nil(), ClientID::nil(), batch, nullptr)); heartbeat_buffer_.clear(); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 684cad003b87a..a49b6268cf4f2 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -145,7 +145,7 @@ ray::Status NodeManager::RegisterGcs() { }; RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe( - UniqueID::nil(), UniqueID::nil(), actor_notification_callback, nullptr)); + JobID::nil(), ClientID::nil(), actor_notification_callback, nullptr)); // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, @@ -167,17 +167,17 @@ ray::Status NodeManager::RegisterGcs() { HeartbeatBatchAdded(heartbeat_batch); }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( - UniqueID::nil(), UniqueID::nil(), heartbeat_batch_added, + JobID::nil(), ClientID::nil(), heartbeat_batch_added, /*subscribe_callback=*/nullptr, /*done_callback=*/nullptr)); // Subscribe to driver table updates. const auto driver_table_handler = [this]( - gcs::AsyncGcsClient *client, const ClientID &client_id, + gcs::AsyncGcsClient *client, const DriverID &client_id, const std::vector &driver_data) { HandleDriverTableUpdate(client_id, driver_data); }; - RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(JobID::nil(), UniqueID::nil(), + RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(JobID::nil(), ClientID::nil(), driver_table_handler, nullptr)); // Start sending heartbeats to the GCS. @@ -210,12 +210,12 @@ void NodeManager::KillWorker(std::shared_ptr worker) { } void NodeManager::HandleDriverTableUpdate( - const ClientID &id, const std::vector &driver_data) { + const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::from_binary(entry.driver_id) << " " << entry.is_dead; if (entry.is_dead) { - auto driver_id = UniqueID::from_binary(entry.driver_id); + auto driver_id = DriverID::from_binary(entry.driver_id); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); // Kill all the workers. The actual cleanup for these workers is done @@ -270,7 +270,7 @@ void NodeManager::Heartbeat() { } ray::Status status = heartbeat_table.Add( - UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, + JobID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, /*success_callback=*/nullptr); RAY_CHECK_OK_PREPEND(status, "Heartbeat failed"); @@ -351,7 +351,7 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { << ". This may be since the node was recently removed."; // We use the nil JobID to broadcast the message to all drivers. RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - JobID::nil(), type, error_message.str(), current_time_ms())); + DriverID::nil(), type, error_message.str(), current_time_ms())); return; } @@ -684,7 +684,7 @@ void NodeManager::ProcessClientMessage( } break; case protocol::MessageType::NotifyUnblocked: { auto message = flatbuffers::GetRoot(message_data); - HandleTaskUnblocked(client, from_flatbuf(*message->task_id())); + HandleTaskUnblocked(client, from_flatbuf(*message->task_id())); } break; case protocol::MessageType::WaitRequest: { ProcessWaitRequestMessage(client, message_data); @@ -698,7 +698,7 @@ void NodeManager::ProcessClientMessage( } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { auto message = flatbuffers::GetRoot(message_data); - std::vector object_ids = from_flatbuf(*message->object_ids()); + std::vector object_ids = from_flatbuf(*message->object_ids()); object_manager_.FreeObjects(object_ids, message->local_only()); } break; case protocol::MessageType::PrepareActorCheckpointRequest: { @@ -719,7 +719,7 @@ void NodeManager::ProcessClientMessage( void NodeManager::ProcessRegisterClientRequestMessage( const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - client->SetClientID(from_flatbuf(*message->client_id())); + client->SetClientID(from_flatbuf(*message->client_id())); auto worker = std::make_shared(message->worker_pid(), message->language(), client); if (message->is_worker()) { @@ -731,11 +731,11 @@ void NodeManager::ProcessRegisterClientRequestMessage( // message is actually the ID of the driver task, while client_id represents the // real driver ID, which can associate all the tasks/actors for a given driver, // which is set to the worker ID. - const JobID driver_task_id = from_flatbuf(*message->driver_id()); - worker->AssignTaskId(driver_task_id); - worker->AssignDriverId(from_flatbuf(*message->client_id())); + const JobID driver_task_id = from_flatbuf(*message->driver_id()); + worker->AssignTaskId(TaskID(driver_task_id)); + worker->AssignDriverId(from_flatbuf(*message->client_id())); worker_pool_.RegisterDriver(std::move(worker)); - local_queues_.AddDriverTaskId(driver_task_id); + local_queues_.AddDriverTaskId(TaskID(driver_task_id)); } } @@ -865,14 +865,14 @@ void NodeManager::ProcessDisconnectClientMessage( if (!intentional_disconnect) { // Push the error to driver. - const JobID &job_id = worker->GetAssignedDriverId(); + const DriverID &driver_id = worker->GetAssignedDriverId(); // TODO(rkn): Define this constant somewhere else. std::string type = "worker_died"; std::ostringstream error_message; error_message << "A worker died or was killed while executing task " << task_id << "."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - job_id, type, error_message.str(), current_time_ms())); + driver_id, type, error_message.str(), current_time_ms())); } } @@ -899,8 +899,9 @@ void NodeManager::ProcessDisconnectClientMessage( DispatchTasks(local_queues_.GetReadyTasksWithResources()); } else if (is_driver) { // The client is a driver. - RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientId(), - /*is_dead=*/true)); + RAY_CHECK_OK( + gcs_client_->driver_table().AppendDriverData(DriverID(client->GetClientId()), + /*is_dead=*/true)); auto driver_id = worker->GetAssignedTaskId(); RAY_CHECK(!driver_id.is_nil()); local_queues_.RemoveDriverTaskId(driver_id); @@ -919,7 +920,7 @@ void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { // Read the task submitted by the client. auto message = flatbuffers::GetRoot(message_data); TaskExecutionSpecification task_execution_spec( - from_flatbuf(*message->execution_dependencies())); + from_flatbuf(*message->execution_dependencies())); TaskSpecification task_spec(*message->task_spec()); Task task(task_execution_spec, task_spec); // Submit the task to the local scheduler. Since the task was submitted @@ -932,7 +933,7 @@ void NodeManager::ProcessFetchOrReconstructMessage( auto message = flatbuffers::GetRoot(message_data); std::vector required_object_ids; for (size_t i = 0; i < message->object_ids()->size(); ++i) { - ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); + ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); if (message->fetch_only()) { // If only a fetch is required, then do not subscribe to the // dependencies to the task dependency manager. @@ -950,7 +951,7 @@ void NodeManager::ProcessFetchOrReconstructMessage( } if (!required_object_ids.empty()) { - const TaskID task_id = from_flatbuf(*message->task_id()); + const TaskID task_id = from_flatbuf(*message->task_id()); HandleTaskBlocked(client, required_object_ids, task_id); } } @@ -959,7 +960,7 @@ void NodeManager::ProcessWaitRequestMessage( const std::shared_ptr &client, const uint8_t *message_data) { // Read the data. auto message = flatbuffers::GetRoot(message_data); - std::vector object_ids = from_flatbuf(*message->object_ids()); + std::vector object_ids = from_flatbuf(*message->object_ids()); int64_t wait_ms = message->timeout(); uint64_t num_required_objects = static_cast(message->num_ready_objects()); bool wait_local = message->wait_local(); @@ -974,7 +975,7 @@ void NodeManager::ProcessWaitRequestMessage( } } - const TaskID ¤t_task_id = from_flatbuf(*message->task_id()); + const TaskID ¤t_task_id = from_flatbuf(*message->task_id()); bool client_blocked = !required_object_ids.empty(); if (client_blocked) { HandleTaskBlocked(client, required_object_ids, current_task_id); @@ -1012,20 +1013,20 @@ void NodeManager::ProcessWaitRequestMessage( void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - JobID job_id = from_flatbuf(*message->job_id()); + DriverID driver_id = from_flatbuf(*message->driver_id()); auto const &type = string_from_flatbuf(*message->type()); auto const &error_message = string_from_flatbuf(*message->error_message()); double timestamp = message->timestamp(); - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message, - timestamp)); + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(driver_id, type, + error_message, timestamp)); } void NodeManager::ProcessPrepareActorCheckpointRequest( const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - ActorID actor_id = from_flatbuf(*message->actor_id()); + ActorID actor_id = from_flatbuf(*message->actor_id()); RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id; const auto &actor_entry = actor_registry_.find(actor_id); RAY_CHECK(actor_entry != actor_registry_.end()); @@ -1037,15 +1038,15 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( const auto task_id = worker->GetAssignedTaskId(); const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING); // Generate checkpoint id and data. - ActorCheckpointID checkpoint_id = UniqueID::from_random(); + ActorCheckpointID checkpoint_id = ActorCheckpointID::from_random(); auto checkpoint_data = actor_entry->second.GenerateCheckpointData(actor_entry->first, task); // Write checkpoint data to GCS. RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( - UniqueID::nil(), checkpoint_id, checkpoint_data, + JobID::nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, - const UniqueID &checkpoint_id, + const ActorCheckpointID &checkpoint_id, const ActorCheckpointDataT &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); @@ -1072,8 +1073,9 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( void NodeManager::ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - ActorID actor_id = from_flatbuf(*message->actor_id()); - ActorCheckpointID checkpoint_id = from_flatbuf(*message->checkpoint_id()); + ActorID actor_id = from_flatbuf(*message->actor_id()); + ActorCheckpointID checkpoint_id = + from_flatbuf(*message->checkpoint_id()); RAY_LOG(DEBUG) << "Actor " << actor_id << " was resumed from checkpoint " << checkpoint_id; checkpoint_id_to_restore_.emplace(actor_id, checkpoint_id); @@ -1093,12 +1095,12 @@ void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_cl switch (message_type_value) { case protocol::MessageType::ConnectClient: { auto message = flatbuffers::GetRoot(message_data); - auto client_id = from_flatbuf(*message->client_id()); + auto client_id = from_flatbuf(*message->client_id()); node_manager_client.SetClientID(client_id); } break; case protocol::MessageType::ForwardTaskRequest: { auto message = flatbuffers::GetRoot(message_data); - TaskID task_id = from_flatbuf(*message->task_id()); + TaskID task_id = from_flatbuf(*message->task_id()); Lineage uncommitted_lineage(*message); const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); @@ -1589,7 +1591,7 @@ bool NodeManager::AssignTask(const Task &task) { const std::string warning_message = worker_pool_.WarningAboutSize(); if (warning_message != "") { RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - JobID::nil(), "worker_pool_large", warning_message, current_time_ms())); + DriverID::nil(), "worker_pool_large", warning_message, current_time_ms())); } } // We couldn't assign this task, as no worker available. @@ -1902,7 +1904,6 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { // Use a copy of the cached task spec to re-execute the task. const Task task = lineage_cache_.GetTaskOrDie(task_id); ResubmitTask(task); - })); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 061ef5ef8969f..1e97c380b1f57 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -326,7 +326,7 @@ class NodeManager { /// \param id An unused value. TODO(rkn): Should this be removed? /// \param driver_data Data associated with a driver table event. /// \return Void. - void HandleDriverTableUpdate(const ClientID &id, + void HandleDriverTableUpdate(const DriverID &id, const std::vector &driver_data); /// Check if certain invariants associated with the task dependency manager diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 13e92d0c4ccce..28a51c7e10fd8 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -201,8 +201,8 @@ ray::Status RayletConnection::AtomicRequestReply( return ReadMessage(reply_type, reply_message); } -RayletClient::RayletClient(const std::string &raylet_socket, const UniqueID &client_id, - bool is_worker, const JobID &driver_id, +RayletClient::RayletClient(const std::string &raylet_socket, const ClientID &client_id, + bool is_worker, const DriverID &driver_id, const Language &language) : client_id_(client_id), is_worker_(is_worker), @@ -323,11 +323,11 @@ ray::Status RayletClient::Wait(const std::vector &object_ids, int num_ return ray::Status::OK(); } -ray::Status RayletClient::PushError(const JobID &job_id, const std::string &type, +ray::Status RayletClient::PushError(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { flatbuffers::FlatBufferBuilder fbb; auto message = ray::protocol::CreatePushErrorRequest( - fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), + fbb, to_flatbuf(fbb, driver_id), fbb.CreateString(type), fbb.CreateString(error_message), timestamp); fbb.Finish(message); @@ -373,7 +373,7 @@ ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, if (!status.ok()) return status; auto reply_message = flatbuffers::GetRoot(reply.get()); - checkpoint_id = ObjectID::from_binary(reply_message->checkpoint_id()->str()); + checkpoint_id = ActorCheckpointID::from_binary(reply_message->checkpoint_id()->str()); return ray::Status::OK(); } diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index d3ea765df65cd..2e07becfc2459 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -9,13 +9,14 @@ #include "ray/raylet/task_spec.h" #include "ray/status.h" -using ray::ActorID; using ray::ActorCheckpointID; +using ray::ActorID; +using ray::ClientID; +using ray::DriverID; using ray::JobID; using ray::ObjectID; using ray::TaskID; using ray::UniqueID; -using ray::ClientID; using MessageType = ray::protocol::MessageType; using ResourceMappingType = @@ -68,8 +69,8 @@ class RayletClient { /// additional message will be sent to register as one. /// \param driver_id The ID of the driver. This is non-nil if the client is a driver. /// \return The connection information. - RayletClient(const std::string &raylet_socket, const UniqueID &client_id, - bool is_worker, const JobID &driver_id, const Language &language); + RayletClient(const std::string &raylet_socket, const ClientID &client_id, + bool is_worker, const DriverID &driver_id, const Language &language); ray::Status Disconnect() { return conn_->Disconnect(); }; @@ -130,7 +131,7 @@ class RayletClient { /// \param The error message. /// \param The timestamp of the error. /// \return ray::Status. - ray::Status PushError(const JobID &job_id, const std::string &type, + ray::Status PushError(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp); /// Store some profile events in the GCS. diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 5e9ae6d7e5218..093f5c236261b 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -322,7 +322,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { task_lease_data->node_manager_id = ClientID::from_random().binary(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = 2 * test_period; - mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data); + mock_gcs_.Add(JobID::nil(), task_id, task_lease_data); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -350,7 +350,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { task_lease_data->node_manager_id = ClientID::from_random().binary(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = reconstruction_timeout_ms_; - mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data); + mock_gcs_.Add(JobID::nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. Run(reconstruction_timeout_ms_ * 2); @@ -404,7 +404,7 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { task_reconstruction_data->node_manager_id = ClientID::from_random().binary(); task_reconstruction_data->num_reconstructions = 0; RAY_CHECK_OK( - mock_gcs_.AppendAt(DriverID::nil(), task_id, task_reconstruction_data, nullptr, + mock_gcs_.AppendAt(JobID::nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index fe4364c4491f4..2f1b64a874803 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -263,7 +263,7 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { task_lease_data->node_manager_id = client_id_.hex(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = it->second.lease_period; - RAY_CHECK_OK(task_lease_table_.Add(DriverID::nil(), task_id, task_lease_data, nullptr)); + RAY_CHECK_OK(task_lease_table_.Add(JobID::nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); it->second.lease_timer->expires_from_now(period); diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index f414d74695652..e0d30bf9ebd6e 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -75,9 +75,9 @@ static inline Task ExampleTask(const std::vector &arguments, task_arguments.emplace_back(std::make_shared(references)); } std::vector function_descriptor(3); - auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0, - task_arguments, num_returns, required_resources, - Language::PYTHON, function_descriptor); + auto spec = TaskSpecification(DriverID::nil(), TaskID::from_random(), 0, task_arguments, + num_returns, required_resources, Language::PYTHON, + function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); execution_spec.IncrementNumForwards(); Task task = Task(execution_spec, spec); diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index a8c0f40fed609..da8bafc60fd4e 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -17,7 +17,7 @@ TaskArgumentByReference::TaskArgumentByReference(const std::vector &re flatbuffers::Offset TaskArgumentByReference::ToFlatbuffer( flatbuffers::FlatBufferBuilder &fbb) const { - return CreateArg(fbb, object_ids_to_flatbuf(fbb, references_)); + return CreateArg(fbb, ids_to_flatbuf(fbb, references_)); } TaskArgumentByValue::TaskArgumentByValue(const uint8_t *value, size_t length) { @@ -57,7 +57,7 @@ TaskSpecification::TaskSpecification(const std::string &string) { } TaskSpecification::TaskSpecification( - const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const Language &language, const std::vector &function_descriptor) @@ -68,7 +68,7 @@ TaskSpecification::TaskSpecification( function_descriptor) {} TaskSpecification::TaskSpecification( - const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, const int64_t max_actor_reconstructions, const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, @@ -100,8 +100,8 @@ TaskSpecification::TaskSpecification( to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, - object_ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), - object_ids_to_flatbuf(fbb, returns), map_to_flatbuf(fbb, required_resources), + ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), + ids_to_flatbuf(fbb, returns), map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, string_vec_to_flatbuf(fbb, function_descriptor)); fbb.Finish(spec); @@ -122,15 +122,15 @@ size_t TaskSpecification::size() const { return spec_.size(); } // Task specification getter methods. TaskID TaskSpecification::TaskId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->task_id()); + return from_flatbuf(*message->task_id()); } -UniqueID TaskSpecification::DriverId() const { +DriverID TaskSpecification::DriverId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->driver_id()); + return from_flatbuf(*message->driver_id()); } TaskID TaskSpecification::ParentTaskId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->parent_task_id()); + return from_flatbuf(*message->parent_task_id()); } int64_t TaskSpecification::ParentCounter() const { auto message = flatbuffers::GetRoot(spec_.data()); @@ -168,7 +168,7 @@ int64_t TaskSpecification::NumReturns() const { ObjectID TaskSpecification::ReturnId(int64_t return_index) const { auto message = flatbuffers::GetRoot(spec_.data()); - return object_ids_from_flatbuf(*message->returns())[return_index]; + return ids_from_flatbuf(*message->returns())[return_index]; } bool TaskSpecification::ArgByRef(int64_t arg_index) const { @@ -184,7 +184,7 @@ int TaskSpecification::ArgIdCount(int64_t arg_index) const { ObjectID TaskSpecification::ArgId(int64_t arg_index, int64_t id_index) const { auto message = flatbuffers::GetRoot(spec_.data()); const auto &object_ids = - object_ids_from_flatbuf(*message->args()->Get(arg_index)->object_ids()); + ids_from_flatbuf(*message->args()->Get(arg_index)->object_ids()); return object_ids[id_index]; } @@ -232,12 +232,12 @@ bool TaskSpecification::IsActorTask() const { return !ActorId().is_nil(); } ActorID TaskSpecification::ActorCreationId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_creation_id()); + return from_flatbuf(*message->actor_creation_id()); } ObjectID TaskSpecification::ActorCreationDummyObjectId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_creation_dummy_object_id()); + return from_flatbuf(*message->actor_creation_dummy_object_id()); } int64_t TaskSpecification::MaxActorReconstructions() const { @@ -247,12 +247,12 @@ int64_t TaskSpecification::MaxActorReconstructions() const { ActorID TaskSpecification::ActorId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_id()); + return from_flatbuf(*message->actor_id()); } ActorHandleID TaskSpecification::ActorHandleId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_handle_id()); + return from_flatbuf(*message->actor_handle_id()); } int64_t TaskSpecification::ActorCounter() const { @@ -267,7 +267,7 @@ ObjectID TaskSpecification::ActorDummyObject() const { std::vector TaskSpecification::NewActorHandles() const { auto message = flatbuffers::GetRoot(spec_.data()); - return object_ids_from_flatbuf(*message->new_actor_handles()); + return ids_from_flatbuf(*message->new_actor_handles()); } } // namespace raylet diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 11e93050b9d15..baa6165c9ede7 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -96,7 +96,7 @@ class TaskSpecification { /// \param num_returns The number of values returned by the task. /// \param required_resources The task's resource demands. /// \param language The language of the worker that must execute the function. - TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id, + TaskSpecification(const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const std::vector> &task_arguments, int64_t num_returns, @@ -129,7 +129,7 @@ class TaskSpecification { /// \param language The language of the worker that must execute the function. /// \param function_descriptor The function descriptor. TaskSpecification( - const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, int64_t max_actor_reconstructions, const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, @@ -164,7 +164,7 @@ class TaskSpecification { // TODO(swang): Finalize and document these methods. TaskID TaskId() const; - UniqueID DriverId() const; + DriverID DriverId() const; TaskID ParentTaskId() const; int64_t ParentCounter() const; std::vector FunctionDescriptor() const; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 4a7f71ea81ea5..c548fc924d674 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -75,7 +75,7 @@ static inline TaskSpecification ExampleTaskSpec( const ActorID actor_id = ActorID::nil(), const Language &language = Language::PYTHON) { std::vector function_descriptor(3); - return TaskSpecification(UniqueID::nil(), TaskID::nil(), 0, ActorID::nil(), + return TaskSpecification(DriverID::nil(), TaskID::nil(), 0, ActorID::nil(), ObjectID::nil(), 0, actor_id, ActorHandleID::nil(), 0, {}, {}, 0, {{}}, {{}}, language, function_descriptor); }