Accelerated DAG: Support channel writes larger than the max gRPC payload size (ray-project#46498)

Signed-off-by: Jack Humphries <[email protected]>
jackhumphries authored Jul 15, 2024
1 parent 4e75921 commit 9752a23
Showing 11 changed files with 232 additions and 95 deletions.
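For context, the core idea of this change: a write larger than the gRPC payload cap is split into multiple PushMutableObject RPCs, each carrying an offset and chunk size, and the receiving node reassembles the chunks into a single backing buffer before releasing the object to readers. Below is a minimal Python sketch of the sender-side chunking; the names `MAX_CHUNK_SIZE` and `push_chunk` are hypothetical stand-ins, not part of the Ray API.

```python
# Sender-side sketch: split one payload into chunks that each fit under the
# gRPC cap. `push_chunk` is a hypothetical stand-in for the PushMutableObject RPC.
MAX_CHUNK_SIZE = 450 * 1024 * 1024  # stay safely under gRPC's ~512 MiB limit


def push_in_chunks(payload: bytes, push_chunk) -> None:
    """Send `payload` as a sequence of (offset, chunk) RPCs."""
    total_size = len(payload)
    offset = 0
    while offset < total_size:
        chunk = payload[offset : offset + MAX_CHUNK_SIZE]
        # The offset lets the receiver copy the chunk directly into place.
        push_chunk(offset=offset, chunk=chunk, total_size=total_size)
        offset += len(chunk)
```
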
11 changes: 0 additions & 11 deletions python/ray/dag/context.py
@@ -20,13 +20,6 @@
# The maximum memory usage for buffered results is 1 GB.
DEFAULT_MAX_BUFFERED_RESULTS = int(os.environ.get("RAY_DAG_max_buffered_results", 1000))

# We still need to add support for transferring objects that are larger than the gRPC
# payload limit, which Ray sets to ~512 MiB (so we set it slightly lower here to be
# safe).
# TODO(jhumphri): Add support for transferring objects that are larger than the gRPC
# payload limit. We can support this by breaking an object into multiple RPCs.
MAX_GRPC_PAYLOAD = int(1024 * 1024 * 450) # 450 MiB


@DeveloperAPI
@dataclass
@@ -60,17 +53,13 @@ class DAGContext:
executions is beyond the DAG capacity, the new execution would
be blocked in the first place; therefore, this limit is only
enforced when it is smaller than the DAG capacity.
max_grpc_payload: The maximum payload size that fits within a single gRPC.
Currently, mutable objects larger than this size cannot be sent via a
multi-node channel, though we plan to support this in the future.
"""

execution_timeout: int = DEFAULT_EXECUTION_TIMEOUT_S
retrieval_timeout: int = DEFAULT_RETRIEVAL_TIMEOUT_S
buffer_size_bytes: int = DEFAULT_BUFFER_SIZE_BYTES
asyncio_max_queue_size: int = DEFAULT_ASYNCIO_MAX_QUEUE_SIZE
max_buffered_results: int = DEFAULT_MAX_BUFFERED_RESULTS
max_grpc_payload: int = MAX_GRPC_PAYLOAD

@staticmethod
def get_current() -> "DAGContext":
54 changes: 54 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
@@ -13,6 +13,7 @@
import pytest

from ray.exceptions import RayChannelError, RayChannelTimeoutError
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
import ray
import ray._private
import ray.cluster_utils
@@ -1261,6 +1262,59 @@ def test_driver_and_actor_as_readers(ray_start_cluster):
dag.experimental_compile()


def test_payload_large(ray_start_cluster):
cluster = ray_start_cluster
# This node is for the driver (including the CompiledDAG.DAGDriverProxyActor).
first_node_handle = cluster.add_node(num_cpus=1)
# This node is for the reader.
second_node_handle = cluster.add_node(num_cpus=1)
ray.init(address=cluster.address)
cluster.wait_for_nodes()

nodes = [first_node_handle.node_id, second_node_handle.node_id]
# Verify that the driver node and the reader node are distinct by deduplicating
# the node IDs and checking that two remain.
nodes = list(set(nodes))
assert len(nodes) == 2

def create_actor(node):
return Actor.options(
scheduling_strategy=NodeAffinitySchedulingStrategy(node, soft=False)
).remote(0)

def get_node_id(self):
return ray.get_runtime_context().get_node_id()

driver_node = get_node_id(None)
nodes.remove(driver_node)

a = create_actor(nodes[0])
a_node = ray.get(a.__ray_call__.remote(get_node_id))
assert a_node == nodes[0]
# Check that the driver and actor are on different nodes.
assert driver_node != a_node

with InputNode() as i:
dag = a.echo.bind(i)

compiled_dag = dag.experimental_compile()

# Ray sets the gRPC payload max size to 512 MiB. We choose a size in this test that
# is a bit larger.
size = 1024 * 1024 * 600
val = b"x" * size

for i in range(3):
ref = compiled_dag.execute(val)
result = ray.get(ref)
assert result == val

# Note: must teardown before starting a new Ray session, otherwise you'll get
# a segfault from the dangling monitor thread upon the new Ray init.
compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
23 changes: 5 additions & 18 deletions python/ray/experimental/channel/shared_memory_channel.py
@@ -300,15 +300,11 @@ def __init__(

self._num_readers = len(self._readers)
if self.is_remote():
from ray.dag.context import DAGContext

if typ.buffer_size_bytes > DAGContext.get_current().max_grpc_payload:
raise ValueError(
"The reader and writer are on different nodes, so the object "
"written to the channel must have a size less than or equal to "
"the max gRPC payload size "
f"({DAGContext.get_current().max_grpc_payload} bytes)."
)
# Even though there may be multiple readers on a remote node, we set
# `self._num_readers` to 1 here. On this local node, only the IO thread in
# the mutable object provider will read the mutable object. The IO thread
# will then send a gRPC with the mutable object contents to the remote node
# where the readers are.
self._num_readers = 1

def _create_reader_ref(
@@ -417,15 +413,6 @@ def _resize_channel_if_needed(self, serialized_value: str, timeout_ms: int):
# include the size of the metadata, so we must account for the size of the
# metadata explicitly.
size = serialized_value.total_bytes + len(serialized_value.metadata)

from ray.dag.context import DAGContext

if size > DAGContext.get_current().max_grpc_payload and self.is_remote():
raise ValueError(
"The reader and writer are on different nodes, so the object written "
"to the channel must have a size less than or equal to the max gRPC "
f"payload size ({DAGContext.get_current().max_grpc_payload} bytes)."
)
if size > self._typ.buffer_size_bytes:
# Now make the channel backing store larger.
self._typ.buffer_size_bytes = size
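Two things change in this file: the gRPC-size checks are dropped, and a remote channel's writer side now counts a single local reader (the IO thread in the mutable object provider), which forwards the contents to the remote readers over gRPC. The following is a rough Python sketch of the resize-if-needed behavior that remains, with a hypothetical `SerializedValue` type standing in for Ray's serialized object; it is not the actual Ray implementation.

```python
from dataclasses import dataclass


@dataclass
class SerializedValue:
    total_bytes: int  # size of the serialized data
    metadata: bytes   # serialization metadata, stored right after the data


def resize_channel_if_needed(buffer_size_bytes: int, value: SerializedValue) -> int:
    """Return the (possibly grown) backing-store size required for `value`."""
    # The backing store holds the data immediately followed by the metadata,
    # so both contribute to the required size.
    size = value.total_bytes + len(value.metadata)
    if size > buffer_size_bytes:
        # Grow the backing store; large values are no longer rejected just
        # because the reader lives on a different node.
        buffer_size_bytes = size
    return buffer_size_bytes
```
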
48 changes: 23 additions & 25 deletions python/ray/tests/test_channel.py
@@ -1,7 +1,6 @@
# coding: utf-8
import logging
import os
import re
import sys
import time
import traceback
@@ -302,9 +301,6 @@ def test_multiple_channels_different_nodes(ray_start_cluster):

@ray.remote(num_cpus=1)
class Actor:
def __init__(self):
pass

def read(self, channel, val):
read_val = channel.read()
if isinstance(val, np.ndarray):
@@ -942,7 +938,7 @@ def write(self, write_error):
sys.platform != "linux" and sys.platform != "darwin",
reason="Requires Linux or Mac.",
)
def test_payload_too_large(ray_start_cluster):
def test_payload_large(ray_start_cluster):
cluster = ray_start_cluster
# This node is for the driver.
first_node_handle = cluster.add_node(num_cpus=1)
@@ -963,6 +959,9 @@ class Actor:
def get_node_id(self):
return ray.get_runtime_context().get_node_id()

def read(self, channel, val):
assert channel.read() == val

def create_actor(node):
return Actor.options(
scheduling_strategy=NodeAffinitySchedulingStrategy(node, soft=False)
@@ -974,22 +973,21 @@ def create_actor(node):
a = create_actor(actor_node)
assert driver_node != ray.get(a.get_node_id.remote())

with pytest.raises(
ValueError,
match=re.escape(
"The reader and writer are on different nodes, so the object written to "
"the channel must have a size less than or equal to the max gRPC payload "
"size (471859200 bytes)."
),
):
ray_channel.Channel(None, [a], 1024 * 1024 * 512)
# Ray sets the gRPC payload max size to 512 MiB. We choose a size in this test that
# is a bit larger.
size = 1024 * 1024 * 600
ch = ray_channel.Channel(None, [a], size)

val = b"x" * size
ch.write(val)
ray.get(a.read.remote(ch, val))


@pytest.mark.skipif(
sys.platform != "linux" and sys.platform != "darwin",
reason="Requires Linux or Mac.",
)
def test_payload_resize_too_large(ray_start_cluster):
def test_payload_resize_large(ray_start_cluster):
cluster = ray_start_cluster
# This node is for the driver.
first_node_handle = cluster.add_node(num_cpus=1)
@@ -1010,6 +1008,9 @@ class Actor:
def get_node_id(self):
return ray.get_runtime_context().get_node_id()

def read(self, channel, val):
assert channel.read() == val

def create_actor(node):
return Actor.options(
scheduling_strategy=NodeAffinitySchedulingStrategy(node, soft=False)
@@ -1021,17 +1022,14 @@ def create_actor(node):
a = create_actor(actor_node)
assert driver_node != ray.get(a.get_node_id.remote())

chan = ray_channel.Channel(None, [a], 1000)
ch = ray_channel.Channel(None, [a], 1000)

with pytest.raises(
ValueError,
match=re.escape(
"The reader and writer are on different nodes, so the object written to "
"the channel must have a size less than or equal to the max gRPC payload "
"size (471859200 bytes)."
),
):
chan.write(b"x" * (1024 * 1024 * 512))
# Ray sets the gRPC payload max size to 512 MiB. We choose a size in this test that
# is a bit larger.
size = 1024 * 1024 * 600
val = b"x" * size
ch.write(val)
ray.get(a.read.remote(ch, val))


@pytest.mark.skipif(
19 changes: 19 additions & 0 deletions src/ray/core_worker/experimental_mutable_object_manager.cc
@@ -246,6 +246,25 @@ Status MutableObjectManager::WriteAcquire(const ObjectID &object_id,
return Status::OK();
}

Status MutableObjectManager::GetObjectBackingStore(const ObjectID &object_id,
int64_t data_size,
int64_t metadata_size,
std::shared_ptr<Buffer> &data) {
RAY_LOG(DEBUG) << "GetObjectBackingStore " << object_id;
absl::ReaderMutexLock guard(&destructor_lock_);

Channel *channel = GetChannel(object_id);
if (!channel) {
return Status::ChannelError("Channel has not been registered");
}
RAY_CHECK(channel->written);

std::unique_ptr<plasma::MutableObject> &object = channel->mutable_object;
int64_t total_size = data_size + metadata_size;
data = SharedMemoryBuffer::Slice(object->buffer, 0, total_size);
return Status::OK();
}

Status MutableObjectManager::WriteRelease(const ObjectID &object_id) {
RAY_LOG(DEBUG) << "WriteRelease " << object_id;
absl::ReaderMutexLock guard(&destructor_lock_);
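`GetObjectBackingStore()` exists so that chunks after the first can reuse the buffer that the initial `WriteAcquire()` already sized, rather than re-acquiring the write lock. As a loose Python analogy (not Ray code), the backing store is just a writable view over the front of the already-acquired buffer:

```python
def get_object_backing_store(buffer: bytearray, data_size: int,
                             metadata_size: int) -> memoryview:
    """Return a writable view over the data-plus-metadata region of `buffer`."""
    total_size = data_size + metadata_size
    # WriteAcquire() is assumed to have already sized the buffer appropriately.
    assert len(buffer) >= total_size
    return memoryview(buffer)[:total_size]
```
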
13 changes: 13 additions & 0 deletions src/ray/core_worker/experimental_mutable_object_manager.h
@@ -107,6 +107,19 @@ class MutableObjectManager : public std::enable_shared_from_this<MutableObjectMa
/// otherwise.
bool ChannelRegistered(const ObjectID &object_id) { return GetChannel(object_id); }

/// Gets the backing store for an object. WriteAcquire() must have already been called
/// before this method is called, and WriteRelease() must not yet have been called.
///
/// \param[in] object_id The ID of the object.
/// \param[in] data_size The size of the data in the object.
/// \param[in] metadata_size The size of the metadata in the object.
/// \param[out] data The mutable object buffer in plasma that can be written to.
/// \return The return status.
Status GetObjectBackingStore(const ObjectID &object_id,
int64_t data_size,
int64_t metadata_size,
std::shared_ptr<Buffer> &data);

/// Acquires a write lock on the object that prevents readers from reading
/// until we are done writing. This is safe for concurrent writers.
///
72 changes: 52 additions & 20 deletions src/ray/core_worker/experimental_mutable_object_provider.cc
@@ -102,33 +102,65 @@ void MutableObjectProvider::HandleRegisterMutableObject(
void MutableObjectProvider::HandlePushMutableObject(
const rpc::PushMutableObjectRequest &request, rpc::PushMutableObjectReply *reply) {
LocalReaderInfo info;
const ObjectID writer_object_id = ObjectID::FromBinary(request.writer_object_id());
{
const ObjectID writer_object_id = ObjectID::FromBinary(request.writer_object_id());
absl::MutexLock guard(&remote_writer_object_to_local_reader_lock_);
auto it = remote_writer_object_to_local_reader_.find(writer_object_id);
RAY_CHECK(it != remote_writer_object_to_local_reader_.end());
info = it->second;
}
size_t data_size = request.data_size();
size_t metadata_size = request.metadata_size();

// Copy both the data and metadata to a local channel.
std::shared_ptr<Buffer> data;
const uint8_t *metadata_ptr =
reinterpret_cast<const uint8_t *>(request.data().data()) + request.data_size();
RAY_CHECK_OK(object_manager_->WriteAcquire(info.local_object_id,
data_size,
metadata_ptr,
metadata_size,
info.num_readers,
data));
RAY_CHECK(data);

size_t total_size = data_size + metadata_size;
size_t total_data_size = request.total_data_size();
size_t total_metadata_size = request.total_metadata_size();
size_t total_size = total_data_size + total_metadata_size;

uint64_t offset = request.offset();
uint64_t chunk_size = request.chunk_size();

uint64_t tmp_written_so_far = 0;
{
absl::MutexLock guard(&written_so_far_lock_);

tmp_written_so_far = written_so_far_[writer_object_id];
written_so_far_[writer_object_id] += chunk_size;
if (written_so_far_[writer_object_id] == total_size) {
written_so_far_[writer_object_id] = 0;
}
}

std::shared_ptr<Buffer> object_backing_store;
if (!tmp_written_so_far) {
// We set `metadata` to nullptr since the metadata is at the end of the object, which
// we will not have until the last chunk is received (or until the two last chunks are
// received, if the metadata happens to span both). The metadata will end up being
// written along with the data as the chunks are written.
RAY_CHECK_OK(object_manager_->WriteAcquire(info.local_object_id,
total_data_size,
/*metadata=*/nullptr,
total_metadata_size,
info.num_readers,
object_backing_store));
} else {
RAY_CHECK_OK(object_manager_->GetObjectBackingStore(info.local_object_id,
total_data_size,
total_metadata_size,
object_backing_store));
}
RAY_CHECK(object_backing_store);

// The buffer has the data immediately followed by the metadata. `WriteAcquire()`
// above checks that the buffer size is at least `total_size`.
memcpy(data->Data(), request.data().data(), total_size);
RAY_CHECK_OK(object_manager_->WriteRelease(info.local_object_id));
// above checks that the buffer size is large enough to hold both the data and the
// metadata.
memcpy(object_backing_store->Data() + offset, request.payload().data(), chunk_size);

size_t total_written = tmp_written_so_far + chunk_size;
RAY_CHECK_LE(total_written, total_size);
if (total_written == total_size) {
// The entire object has been written, so call `WriteRelease()`.
RAY_CHECK_OK(object_manager_->WriteRelease(info.local_object_id));
reply->set_done(true);
} else {
reply->set_done(false);
}
}

Status MutableObjectProvider::WriteAcquire(const ObjectID &object_id,
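The receiver-side logic above boils down to: acquire the backing store when the first chunk of an object arrives, copy each chunk at its offset, and release the object once the received byte count reaches the total. Here is a condensed Python sketch of that bookkeeping; `write_acquire`, `get_backing_store`, and `write_release` are hypothetical callbacks that loosely mirror the MutableObjectManager calls, not real Ray APIs.

```python
# Receiver-side sketch of chunk reassembly. The three callbacks are hypothetical
# stand-ins for WriteAcquire / GetObjectBackingStore / WriteRelease.
written_so_far: dict[str, int] = {}


def handle_push(object_id: str, total_size: int, offset: int, chunk: bytes,
                write_acquire, get_backing_store, write_release) -> bool:
    """Copy one chunk into the object's backing store; return True when complete."""
    already_written = written_so_far.get(object_id, 0)
    if already_written == 0:
        # First chunk: acquire a buffer sized for the whole object.
        buffer = write_acquire(object_id, total_size)
    else:
        # Later chunks: reuse the buffer acquired for the first chunk.
        buffer = get_backing_store(object_id)

    buffer[offset : offset + len(chunk)] = chunk
    written_so_far[object_id] = already_written + len(chunk)

    if written_so_far[object_id] == total_size:
        written_so_far[object_id] = 0  # reset so the object can be written again
        write_release(object_id)       # readers may now observe the object
        return True
    return False
```
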
8 changes: 8 additions & 0 deletions src/ray/core_worker/experimental_mutable_object_provider.h
@@ -193,6 +193,14 @@ class MutableObjectProvider {
// and then send the changes to remote nodes via the network.
std::vector<std::unique_ptr<std::thread>> io_threads_;

// Protects the `written_so_far_` map.
absl::Mutex written_so_far_lock_;
// For objects larger than the gRPC max payload size *that this node receives from a
// writer node*, this map tracks how many bytes have been received so far for a single
// object write.
std::unordered_map<ObjectID, uint64_t> written_so_far_
ABSL_GUARDED_BY(written_so_far_lock_);

friend class MutableObjectProvider_MutableObjectBufferReadRelease_Test;
};

4 changes: 2 additions & 2 deletions src/ray/core_worker/test/mutable_object_provider_test.cc
@@ -222,8 +222,8 @@ TEST(MutableObjectProvider, HandlePushMutableObject) {

ray::rpc::PushMutableObjectRequest request;
request.set_writer_object_id(object_id.Binary());
request.set_data_size(0);
request.set_metadata_size(0);
request.set_total_data_size(0);
request.set_total_metadata_size(0);

ray::rpc::PushMutableObjectReply reply;
provider.HandlePushMutableObject(request, &reply);