Skip to content

Commit

Permalink
Add a GCS table for the xray task flatbuffer (ray-project#1775)
Browse files Browse the repository at this point in the history
* Introduce Task flatbuffer into xray, add to GCS

* Compile and test raylet TaskTable
  • Loading branch information
stephanie-wang authored Mar 23, 2018
1 parent 72595cc commit 0ad1054
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/common/redis_module/ray_redis_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
}

static const char *table_prefixes[] = {
NULL, "TASK:", "CLIENT:", "OBJECT:", "FUNCTION:",
NULL, "TASK:", "TASK:", "CLIENT:", "OBJECT:", "FUNCTION:",
};

/// Parse a Redis string into a TablePubsub channel.
Expand Down
4 changes: 2 additions & 2 deletions src/ray/common/client_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ template <class T>
void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &error) {
if (error) {
// If there was an error, disconnect the client.
read_type_ = MessageType_DisconnectClient;
read_type_ = protocol::MessageType_DisconnectClient;
read_length_ = 0;
ProcessMessage(error);
return;
Expand Down Expand Up @@ -81,7 +81,7 @@ template <class T>
void ClientConnection<T>::ProcessMessage(const boost::system::error_code &error) {
if (error) {
// TODO(hme): Disconnect differently & remove dependency on node_manager_generated.h
read_type_ = MessageType_DisconnectClient;
read_type_ = protocol::MessageType_DisconnectClient;
}
manager_.ProcessClientMessage(this->shared_from_this(), read_type_,
read_message_.data());
Expand Down
3 changes: 3 additions & 0 deletions src/ray/gcs/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Status AsyncGcsClient::Connect(const std::string &address, int port,
RAY_RETURN_NOT_OK(context_->Connect(address, port));
object_table_.reset(new ObjectTable(context_, this));
task_table_.reset(new TaskTable(context_, this));
raylet_task_table_.reset(new raylet::TaskTable(context_, this));
client_table_.reset(new ClientTable(context_, this, client_info));
// TODO(swang): Call the client table's Connect() method here. To do this,
// we need to make sure that we are attached to an event loop first. This
Expand All @@ -41,6 +42,8 @@ ObjectTable &AsyncGcsClient::object_table() { return *object_table_; }

TaskTable &AsyncGcsClient::task_table() { return *task_table_; }

/// Returns a reference to the raylet task table, i.e. the table keyed by
/// TaskID whose entries are ray::protocol::Task flatbuffers. The underlying
/// pointer is assigned in Connect(); presumably this must not be called
/// before Connect() has run -- TODO confirm.
raylet::TaskTable &AsyncGcsClient::raylet_task_table() { return *raylet_task_table_; }

ClientTable &AsyncGcsClient::client_table() { return *client_table_; }

FunctionTable &AsyncGcsClient::function_table() { return *function_table_; }
Expand Down
2 changes: 2 additions & 0 deletions src/ray/gcs/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class RAY_EXPORT AsyncGcsClient {
inline ConfigTable &config_table();
ObjectTable &object_table();
TaskTable &task_table();
raylet::TaskTable &raylet_task_table();
ClientTable &client_table();
inline ErrorTable &error_table();

Expand All @@ -63,6 +64,7 @@ class RAY_EXPORT AsyncGcsClient {
std::unique_ptr<ClassTable> class_table_;
std::unique_ptr<ObjectTable> object_table_;
std::unique_ptr<TaskTable> task_table_;
std::unique_ptr<raylet::TaskTable> raylet_task_table_;
std::unique_ptr<ClientTable> client_table_;
std::shared_ptr<RedisContext> context_;
std::unique_ptr<RedisAsioClient> asio_async_client_;
Expand Down
38 changes: 27 additions & 11 deletions src/ray/gcs/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,42 @@ void LookupFailed(gcs::AsyncGcsClient *client, const UniqueID &id) {
test->Stop();
}

void TestObjectTable(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
auto data = std::make_shared<ObjectTableDataT>();
data->managers.push_back("A");
data->managers.push_back("B");
ObjectID object_id = ObjectID::from_random();
RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, &ObjectAdded));
RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, &Lookup, &LookupFailed));
void TestTableLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
TaskID task_id = TaskID::from_random();
auto data = std::make_shared<protocol::TaskT>();
data->task_specification = "123";

auto add_callback = [data](gcs::AsyncGcsClient *client, const UniqueID &id,
const protocol::TaskT &d) {
ASSERT_EQ(data->task_specification, d.task_specification);
};

auto lookup_callback = [data](gcs::AsyncGcsClient *client, const UniqueID &id,
const protocol::TaskT &d) {
ASSERT_EQ(data->task_specification, d.task_specification);
test->Stop();
};

auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) {
RAY_CHECK(false);
};

RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback));
RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback,
failure_callback));
// Run the event loop. The loop will only stop if the Lookup callback is
// called (or an assertion failure).
test->Start();
}

TEST_F(TestGcsWithAe, TestObjectTable) {
TEST_F(TestGcsWithAe, TestTableLookup) {
test = this;
TestObjectTable(job_id_, client_);
TestTableLookup(job_id_, client_);
}

TEST_F(TestGcsWithAsio, TestObjectTable) {
TEST_F(TestGcsWithAsio, TestTableLookup) {
test = this;
TestObjectTable(job_id_, client_);
TestTableLookup(job_id_, client_);
}

void TestLookupFailure(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
Expand Down
2 changes: 2 additions & 0 deletions src/ray/gcs/format/gcs.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ enum Language:int {
enum TablePrefix:int {
UNUSED = 0,
TASK,
RAYLET_TASK,
CLIENT,
OBJECT,
FUNCTION
Expand All @@ -16,6 +17,7 @@ enum TablePrefix:int {
enum TablePubsub:int {
NO_PUBLISH = 0,
TASK,
RAYLET_TASK,
CLIENT,
OBJECT,
ACTOR
Expand Down
1 change: 1 addition & 0 deletions src/ray/gcs/tables.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ const ClientTableDataT &ClientTable::GetClient(const ClientID &client_id) {
}
}

// Explicit instantiations of the Table template for the (ID, data) pairs
// listed below. The first one, Table<TaskID, ray::protocol::Task>, is the
// base class of raylet::TaskTable. Presumably these are needed because
// Table's member functions are defined in this translation unit rather than
// in the header -- confirm.
template class Table<TaskID, ray::protocol::Task>;
template class Table<TaskID, TaskTableData>;
template class Table<ObjectID, ObjectTableData>;

Expand Down
13 changes: 13 additions & 0 deletions src/ray/gcs/tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "ray/gcs/format/gcs_generated.h"
#include "ray/gcs/redis_context.h"
#include "ray/raylet/format/node_manager_generated.h"

// TODO(pcm): Remove this
#include "task.h"
Expand Down Expand Up @@ -165,6 +166,18 @@ using ClassTable = Table<ClassID, ClassTableData>;
// TODO(swang): Set the pubsub channel for the actor table.
using ActorTable = Table<ActorID, ActorTableData>;

namespace raylet {

/// GCS table for raylet tasks: a Table keyed by TaskID whose entries are
/// ray::protocol::Task flatbuffers. It uses the RAYLET_TASK table prefix and
/// the RAYLET_TASK pubsub channel.
class TaskTable : public Table<TaskID, ray::protocol::Task> {
 public:
  /// Forwards both arguments to the Table base class, then selects the
  /// prefix and pubsub channel for raylet tasks.
  ///
  /// \param context Redis context handed to the base Table.
  /// \param client GCS client handed to the base Table.
  TaskTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
      : Table<TaskID, ray::protocol::Task>(context, client) {
    this->prefix_ = TablePrefix_RAYLET_TASK;
    this->pubsub_channel_ = TablePubsub_RAYLET_TASK;
  }
};

}  // namespace raylet

class TaskTable : public Table<TaskID, TaskTableData> {
public:
TaskTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
Expand Down
19 changes: 19 additions & 0 deletions src/ray/raylet/format/node_manager.fbs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
// Local scheduler protocol specification

// TODO(swang): We put the flatbuffer types in a separate namespace for now to
// avoid conflicts with legacy Ray types.
namespace ray.protocol;

enum MessageType:int {
// Task is submitted to the local scheduler. This is sent from a worker to a
// local scheduler.
Expand Down Expand Up @@ -50,6 +54,21 @@ enum MessageType:int {
SetActorFrontier
}

// Execution-time state of a task. It is embedded in the `Task` table as the
// `task_execution_spec` field, next to the task specification.
table TaskExecutionSpecification {
// A list of object IDs representing the dependencies of this task that may
// change at execution time.
dependencies: [string];
// The last time this task was received for scheduling.
last_timestamp: double;
// The number of times this task was spilled back by local schedulers.
num_forwards: int;
}

// A task: its specification together with its execution-time state. This is
// the data type stored in the GCS raylet task table.
table Task {
// The task specification, carried as a string. Presumably this is the
// serialized specification -- confirm against the code that fills it in.
task_specification: string;
// The execution-time state of the task (dependencies, last timestamp and
// number of forwards).
task_execution_spec: TaskExecutionSpecification;
}

table SubmitTaskRequest {
execution_dependencies: [string];
task_spec: string;
Expand Down
23 changes: 12 additions & 11 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
RAY_LOG(DEBUG) << "Message of type " << message_type;

switch (message_type) {
case MessageType_RegisterClientRequest: {
auto message = flatbuffers::GetRoot<RegisterClientRequest>(message_data);
case protocol::MessageType_RegisterClientRequest: {
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
if (message->is_worker()) {
// Create a new worker from the registration request.
std::shared_ptr<Worker> worker(new Worker(message->worker_pid(), client));
Expand All @@ -45,14 +45,15 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
// is legacy code and should be removed once actor creation tasks are
// implemented.
flatbuffers::FlatBufferBuilder fbb;
auto reply = CreateRegisterClientReply(fbb, fbb.CreateVector(std::vector<int>()));
auto reply =
protocol::CreateRegisterClientReply(fbb, fbb.CreateVector(std::vector<int>()));
fbb.Finish(reply);
// Reply to the worker's registration request, then listen for more
// messages.
client->WriteMessage(MessageType_RegisterClientReply, fbb.GetSize(),
client->WriteMessage(protocol::MessageType_RegisterClientReply, fbb.GetSize(),
fbb.GetBufferPointer());
} break;
case MessageType_GetTask: {
case protocol::MessageType_GetTask: {
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
RAY_CHECK(worker);
// If the worker was assigned a task, mark it as finished.
Expand All @@ -69,16 +70,16 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
AssignTask(scheduled_tasks.front());
}
} break;
case MessageType_DisconnectClient: {
case protocol::MessageType_DisconnectClient: {
// Remove the dead worker from the pool and stop listening for messages.
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
if (worker) {
worker_pool_.DisconnectWorker(worker);
}
} break;
case MessageType_SubmitTask: {
case protocol::MessageType_SubmitTask: {
// Read the task submitted by the client.
auto message = flatbuffers::GetRoot<SubmitTaskRequest>(message_data);
auto message = flatbuffers::GetRoot<protocol::SubmitTaskRequest>(message_data);
TaskExecutionSpecification task_execution_spec(
from_flatbuf(*message->execution_dependencies()));
TaskSpecification task_spec(*message->task_spec());
Expand Down Expand Up @@ -152,10 +153,10 @@ void NodeManager::AssignTask(const Task &task) {

flatbuffers::FlatBufferBuilder fbb;
const TaskSpecification &spec = task.GetTaskSpecification();
auto message = CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb),
fbb.CreateVector(std::vector<int>()));
auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb),
fbb.CreateVector(std::vector<int>()));
fbb.Finish(message);
worker->Connection()->WriteMessage(MessageType_ExecuteTask, fbb.GetSize(),
worker->Connection()->WriteMessage(protocol::MessageType_ExecuteTask, fbb.GetSize(),
fbb.GetBufferPointer());
worker->AssignTaskId(spec.TaskId());
local_queues_.QueueRunningTasks(std::vector<Task>({task}));
Expand Down

0 comments on commit 0ad1054

Please sign in to comment.