Skip to content

Commit

Permalink
[C++] Add hash table to Redis-Module (ray-project#4911)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuhong Guo authored and raulchen committed Jun 7, 2019
1 parent cbc67fc commit 5eff47b
Show file tree
Hide file tree
Showing 15 changed files with 686 additions and 93 deletions.
2 changes: 1 addition & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ flatbuffer_py_library(
"ErrorTableData.py",
"ErrorType.py",
"FunctionTableData.py",
"GcsTableEntry.py",
"GcsEntry.py",
"HeartbeatBatchTableData.py",
"HeartbeatTableData.py",
"Language.py",
Expand Down
2 changes: 1 addition & 1 deletion doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"ray.core.generated.EntryType",
"ray.core.generated.ErrorTableData",
"ray.core.generated.ErrorType",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.GcsEntry",
"ray.core.generated.HeartbeatBatchTableData",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.Language",
Expand Down
2 changes: 1 addition & 1 deletion java/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ flatbuffers_generated_files = [
"ErrorTableData.java",
"ErrorType.java",
"FunctionTableData.java",
"GcsTableEntry.java",
"GcsEntry.java",
"HeartbeatBatchTableData.java",
"HeartbeatTableData.java",
"Language.java",
Expand Down
4 changes: 2 additions & 2 deletions python/ray/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.DriverTableData import DriverTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.GcsEntry import GcsEntry
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.Language import Language
Expand All @@ -25,7 +25,7 @@
"ClientTableData",
"DriverTableData",
"ErrorTableData",
"GcsTableEntry",
"GcsEntry",
"HeartbeatBatchTableData",
"HeartbeatTableData",
"Language",
Expand Down
6 changes: 2 additions & 4 deletions python/ray/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def subscribe(self, channel):
def xray_heartbeat_batch_handler(self, unused_channel, data):
"""Handle an xray heartbeat batch message from Redis."""

gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
heartbeat_data = gcs_entries.Entries(0)

message = (ray.gcs_utils.HeartbeatBatchTableData.
Expand Down Expand Up @@ -208,8 +207,7 @@ def xray_driver_removed_handler(self, unused_channel, data):
unused_channel: The message channel.
data: The message data.
"""
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
driver_data = gcs_entries.Entries(0)
message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
driver_data, 0)
Expand Down
22 changes: 8 additions & 14 deletions python/ray/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _parse_client_table(redis_client):
return []

node_info = {}
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)

ordered_client_ids = []

Expand Down Expand Up @@ -248,8 +248,7 @@ def _object_table(self, object_id):
object_id.binary())
if message is None:
return {}
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)

assert gcs_entry.EntriesLength() > 0

Expand Down Expand Up @@ -307,8 +306,7 @@ def _task_table(self, task_id):
"", task_id.binary())
if message is None:
return {}
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)

assert gcs_entries.EntriesLength() == 1

Expand Down Expand Up @@ -431,8 +429,7 @@ def _profile_table(self, batch_id):
if message is None:
return []

gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)

profile_events = []
for i in range(gcs_entries.EntriesLength()):
Expand Down Expand Up @@ -815,9 +812,8 @@ def available_resources(self):
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
continue
data = raw_message["data"]
gcs_entries = (
ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0))
gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
data, 0))
heartbeat_data = gcs_entries.Entries(0)
message = (ray.gcs_utils.HeartbeatTableData.
GetRootAsHeartbeatTableData(heartbeat_data, 0))
Expand Down Expand Up @@ -871,8 +867,7 @@ def _error_messages(self, driver_id):
if message is None:
return []

gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
error_messages = []
for i in range(gcs_entries.EntriesLength()):
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
Expand Down Expand Up @@ -934,8 +929,7 @@ def actor_checkpoint_info(self, actor_id):
)
if message is None:
return None
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
entry = (
ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData(
gcs_entry.Entries(0), 0))
Expand Down
2 changes: 1 addition & 1 deletion python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
if msg is None:
threads_stopped.wait(timeout=0.01)
continue
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
msg["data"], 0)
assert gcs_entry.EntriesLength() == 1
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
Expand Down
3 changes: 3 additions & 0 deletions src/ray/gcs/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
profile_table_.reset(new ProfileTable(shard_contexts_, this));
actor_checkpoint_table_.reset(new ActorCheckpointTable(shard_contexts_, this));
actor_checkpoint_id_table_.reset(new ActorCheckpointIdTable(shard_contexts_, this));
resource_table_.reset(new DynamicResourceTable({primary_context_}, this));
command_type_ = command_type;

// TODO(swang): Call the client table's Connect() method here. To do this,
Expand Down Expand Up @@ -229,6 +230,8 @@ ActorCheckpointIdTable &AsyncGcsClient::actor_checkpoint_id_table() {
return *actor_checkpoint_id_table_;
}

// Accessor for the dynamic resource table held by this client. The returned
// reference aliases the object behind the `resource_table_` member.
DynamicResourceTable &AsyncGcsClient::resource_table() {
  return *resource_table_;
}

} // namespace gcs

} // namespace ray
2 changes: 2 additions & 0 deletions src/ray/gcs/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class RAY_EXPORT AsyncGcsClient {
ProfileTable &profile_table();
ActorCheckpointTable &actor_checkpoint_table();
ActorCheckpointIdTable &actor_checkpoint_id_table();
DynamicResourceTable &resource_table();

// We also need something to export generic code to run on workers from the
// driver (to set the PYTHONPATH)
Expand Down Expand Up @@ -94,6 +95,7 @@ class RAY_EXPORT AsyncGcsClient {
std::unique_ptr<ClientTable> client_table_;
std::unique_ptr<ActorCheckpointTable> actor_checkpoint_table_;
std::unique_ptr<ActorCheckpointIdTable> actor_checkpoint_id_table_;
std::unique_ptr<DynamicResourceTable> resource_table_;
// The following contexts write to the data shard
std::vector<std::shared_ptr<RedisContext>> shard_contexts_;
std::vector<std::unique_ptr<RedisAsioClient>> shard_asio_async_clients_;
Expand Down
172 changes: 162 additions & 10 deletions src/ray/gcs/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -657,13 +657,12 @@ void TestSetSubscribeAll(const DriverID &driver_id,

// Callback for a notification.
auto notification_callback = [object_ids, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsTableNotificationMode notification_mode,
gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode,
const std::vector<ObjectTableDataT> data) {
if (test->NumCallbacks() < 3 * 3) {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
} else {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::REMOVE);
ASSERT_EQ(change_mode, GcsChangeMode::REMOVE);
}
ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]);
// Check that we get notifications in the same order as the writes.
Expand Down Expand Up @@ -894,10 +893,9 @@ void TestSetSubscribeId(const DriverID &driver_id,
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [object_id2, managers2](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsTableNotificationMode notification_mode,
gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, object_id2);
// Check that we get notifications in the same order as the writes.
Expand Down Expand Up @@ -1111,10 +1109,9 @@ void TestSetSubscribeCancel(const DriverID &driver_id,
// The callback for a notification from the object table. This should only be
// received for the object that we requested notifications for.
auto notification_callback = [object_id, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsTableNotificationMode notification_mode,
gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
ASSERT_EQ(id, object_id);
// Check that we get a duplicate notification for the first write. We get a
// duplicate notification because notifications
Expand Down Expand Up @@ -1307,6 +1304,161 @@ TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) {
TestClientTableMarkDisconnected(driver_id_, client_);
}

// Exercises the hash-table API exposed by the dynamic resource table:
// Subscribe/RequestNotifications, Update, Lookup and RemoveEntries.
//
// The test queues a fixed sequence of operations for one randomly generated
// client ID, validates the data handed to every callback, and stops the event
// loop once `expected_count` callbacks have been counted.
//
// driver_id: the driver ID passed to every resource table operation.
// client: the GCS client whose resource table is under test.
void TestHashTable(const DriverID &driver_id,
                   std::shared_ptr<gcs::AsyncGcsClient> client) {
  // Callbacks counted below: 1 subscribe + 2 update (update_callback1 is
  // registered twice) + 5 lookup + 1 remove = 9. The remaining 5 are expected
  // to be notifications, presumably one per Update/RemoveEntries call.
  const int expected_count = 14;
  ClientID client_id = ClientID::FromRandom();

  // Prepare the first resource map: data_map1.
  auto cpu_data = std::make_shared<RayResourceT>();
  cpu_data->resource_name = "CPU";
  cpu_data->resource_capacity = 100;
  auto gpu_data = std::make_shared<RayResourceT>();
  gpu_data->resource_name = "GPU";
  gpu_data->resource_capacity = 2;
  DynamicResourceTable::DataMap data_map1;
  data_map1.emplace("CPU", cpu_data);
  data_map1.emplace("GPU", gpu_data);

  // Prepare the second resource map: data_map2, which decreases CPU,
  // increases GPU and adds a new CUSTOM resource compared to data_map1.
  auto data_cpu = std::make_shared<RayResourceT>();
  data_cpu->resource_name = "CPU";
  data_cpu->resource_capacity = 50;
  auto data_gpu = std::make_shared<RayResourceT>();
  data_gpu->resource_name = "GPU";
  data_gpu->resource_capacity = 10;
  auto data_custom = std::make_shared<RayResourceT>();
  data_custom->resource_name = "CUSTOM";
  data_custom->resource_capacity = 2;
  DynamicResourceTable::DataMap data_map2;
  data_map2.emplace("CPU", data_cpu);
  data_map2.emplace("GPU", data_gpu);
  data_map2.emplace("CUSTOM", data_custom);

  // Common comparison: both maps must hold the same keys, and every entry
  // must agree on resource name and capacity.
  auto compare_test = [](const DynamicResourceTable::DataMap &data1,
                         const DynamicResourceTable::DataMap &data2) {
    ASSERT_EQ(data1.size(), data2.size());
    for (const auto &data : data1) {
      auto iter = data2.find(data.first);
      ASSERT_TRUE(iter != data2.end());
      ASSERT_EQ(iter->second->resource_name, data.second->resource_name);
      ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity);
    }
  };

  // The subscription callback only needs to be counted.
  auto subscribe_callback = [](AsyncGcsClient *client) {
    test->IncrementNumCallbacks();
  };

  auto notification_callback = [data_map1, data_map2, compare_test](
      AsyncGcsClient *client, const ClientID &id, const GcsChangeMode change_mode,
      const DynamicResourceTable::DataMap &data) {
    if (change_mode == GcsChangeMode::REMOVE) {
      // Both removals in this test delete GPU plus exactly one other existing
      // key (CUSTOM in step 2, CPU in step 4). The key "None-Existent" will
      // not appear in the notification.
      ASSERT_EQ(data.size(), 2u);
      ASSERT_TRUE(data.find("GPU") != data.end());
      ASSERT_TRUE(data.find("CUSTOM") != data.end() || data.find("CPU") != data.end());
    } else {
      // Updates carry either data_map1 (2 entries) or data_map2 (3 entries).
      if (data.size() == 2) {
        compare_test(data_map1, data);
      } else if (data.size() == 3) {
        compare_test(data_map2, data);
      } else {
        ASSERT_TRUE(false) << "Unexpected number of entries in notification: "
                           << data.size();
      }
    }
    test->IncrementNumCallbacks();
    // The order of the last notification and the last lookup callback is not
    // guaranteed, so both check whether the test is finished.
    if (test->NumCallbacks() == expected_count) {
      test->Stop();
    }
  };

  // Step 0: Subscribe to changes of the hash table.
  RAY_CHECK_OK(client->resource_table().Subscribe(
      driver_id, ClientID::Nil(), notification_callback, subscribe_callback));
  RAY_CHECK_OK(client->resource_table().RequestNotifications(
      driver_id, client_id, client->client_table().GetLocalClientId()));

  // Step 1: Add elements to the hash table.
  auto update_callback1 = [data_map1, compare_test](
      AsyncGcsClient *client, const ClientID &id,
      const DynamicResourceTable::DataMap &callback_data) {
    compare_test(data_map1, callback_data);
    test->IncrementNumCallbacks();
  };
  RAY_CHECK_OK(
      client->resource_table().Update(driver_id, client_id, data_map1, update_callback1));
  auto lookup_callback1 = [data_map1, compare_test](
      AsyncGcsClient *client, const ClientID &id,
      const DynamicResourceTable::DataMap &callback_data) {
    compare_test(data_map1, callback_data);
    test->IncrementNumCallbacks();
  };
  RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback1));

  // Step 2: Decrease one element, increase one and add a new one.
  RAY_CHECK_OK(client->resource_table().Update(driver_id, client_id, data_map2, nullptr));
  auto lookup_callback2 = [data_map2, compare_test](
      AsyncGcsClient *client, const ClientID &id,
      const DynamicResourceTable::DataMap &callback_data) {
    compare_test(data_map2, callback_data);
    test->IncrementNumCallbacks();
  };
  RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback2));

  // Remove two existing keys and one key that was never added.
  std::vector<std::string> delete_keys({"GPU", "CUSTOM", "None-Existent"});
  auto remove_callback = [delete_keys](AsyncGcsClient *client, const ClientID &id,
                                       const std::vector<std::string> &callback_data) {
    // All requested keys are reported back, in order, even if a key does not
    // exist in the table. Comparing the whole vectors also checks the length,
    // so no element is read out of range.
    ASSERT_EQ(callback_data, delete_keys);
    test->IncrementNumCallbacks();
  };
  RAY_CHECK_OK(client->resource_table().RemoveEntries(driver_id, client_id, delete_keys,
                                                      remove_callback));
  // After the removal only the CPU entry of data_map2 should remain.
  DynamicResourceTable::DataMap data_map3(data_map2);
  data_map3.erase("GPU");
  data_map3.erase("CUSTOM");
  auto lookup_callback3 = [data_map3, compare_test](
      AsyncGcsClient *client, const ClientID &id,
      const DynamicResourceTable::DataMap &callback_data) {
    compare_test(data_map3, callback_data);
    test->IncrementNumCallbacks();
  };
  RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback3));

  // Step 3: Reset the resources to data_map1.
  RAY_CHECK_OK(
      client->resource_table().Update(driver_id, client_id, data_map1, update_callback1));
  auto lookup_callback4 = [data_map1, compare_test](
      AsyncGcsClient *client, const ClientID &id,
      const DynamicResourceTable::DataMap &callback_data) {
    compare_test(data_map1, callback_data);
    test->IncrementNumCallbacks();
  };
  RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback4));

  // Step 4: Removing all elements will remove the whole hash table from GCS,
  // so the following lookup is expected to return no entries.
  RAY_CHECK_OK(client->resource_table().RemoveEntries(
      driver_id, client_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr));
  auto lookup_callback5 = [](AsyncGcsClient *client, const ClientID &id,
                             const DynamicResourceTable::DataMap &callback_data) {
    ASSERT_TRUE(callback_data.empty());
    test->IncrementNumCallbacks();
    // The order of the last notification and the last lookup callback is not
    // guaranteed, so both check whether the test is finished.
    if (test->NumCallbacks() == expected_count) {
      test->Stop();
    }
  };
  RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback5));

  // Run the event loop until one of the callbacks above stops the test.
  test->Start();
  ASSERT_EQ(test->NumCallbacks(), expected_count);
}

// Registers the hash table test on the TestGcsWithAsio fixture.
TEST_F(TestGcsWithAsio, TestHashTable) {
// Publish this fixture through the file-level `test` pointer (declared outside
// this view) before running: TestHashTable's callbacks count and stop via it.
test = this;
TestHashTable(driver_id_, client_);
}

#undef TEST_MACRO

} // namespace gcs
Expand Down
Loading

0 comments on commit 5eff47b

Please sign in to comment.