Skip to content

Commit

Permalink
Fix for incorrect channel read behavior after accelerated DAG teardown (
Browse files Browse the repository at this point in the history
ray-project#46320)

Prior to this PR (described in ray-project#46284), calling `ray.get()` on a
`CompiledDAGRef` (i.e., a channel) after DAG teardown would return a
large series of zeroes. This issue could be reproduced with this script:
```
import ray
from ray.dag import InputNode

@ray.remote
class Actor:
    def foo(self, arg):
        return arg
        
a = Actor.remote()
with InputNode() as inp:
    dag = a.foo.bind(inp)
    
dag = dag.experimental_compile()
x = dag.execute(1)
dag.teardown()
# `ray.get(x)` returns a large series of zeroes.
print(ray.get(x))
```

This issue happened because the channel was unregistered with the
mutable object manager on DAG teardown, and thus on a subsequent access
to the channel, the core worker thought the channel reference was for a
normal immutable Ray object rather than for a channel mutable object.
Thus, the core worker was returning the raw underlying memory for the
mutable object, and the memory buffers were sized equal to the total
size of the underlying memory, not the amount of data in the mutable
object.

This PR fixes this issue by checking whether a channel is either
currently registered or was previously registered, rather than checking
only that the channel is currently registered.

Signed-off-by: Jack Humphries <[email protected]>
  • Loading branch information
jackhumphries authored Jul 1, 2024
1 parent 7be6836 commit 8a0d633
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 6 deletions.
20 changes: 20 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,26 @@ def test_composite_channel_multi_output(self, ray_start_regular_shared):
compiled_dag.teardown()


def test_channel_access_after_close(ray_start_regular_shared):
    """Reading a channel after accelerated DAG teardown must raise.

    Regression test for issue #46284: once the compiled DAG has been torn
    down, the channel is closed, so ``ray.get()`` on a reference obtained
    before the teardown is expected to raise ``RayChannelError``.
    """

    @ray.remote
    class Actor:
        def foo(self, arg):
            return arg

    actor = Actor.remote()
    with InputNode() as inp:
        dag = actor.foo.bind(inp)

    dag = dag.experimental_compile()
    output_ref = dag.execute(1)
    dag.teardown()

    # The reference outlives the DAG; reading it now must fail.
    with pytest.raises(RayChannelError, match="Channel closed."):
        ray.get(output_ref)


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
10 changes: 9 additions & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1554,8 +1554,16 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids,
// Check whether these are experimental.Channel objects.
bool is_experimental_channel = false;
for (const ObjectID &id : ids) {
if (experimental_mutable_object_provider_->ReaderChannelRegistered(id)) {
Status status = experimental_mutable_object_provider_->GetChannelStatus(id);
if (status.ok()) {
is_experimental_channel = true;
// We continue rather than break because we want to check that *all* of the
// objects are either experimental or not experimental. We cannot have a mix of
// the two.
continue;
} else if (status.IsChannelError()) {
// The channel has been closed.
return status;
} else if (is_experimental_channel) {
return Status::NotImplemented(
"ray.get can only be called on all normal objects, or all "
Expand Down
13 changes: 13 additions & 0 deletions src/ray/core_worker/experimental_mutable_object_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,15 @@ Status MutableObjectManager::SetErrorAll() {
return ret;
}

Status MutableObjectManager::IsChannelClosed(const ObjectID &object_id) {
  // Look the channel up once; a null result means nothing is known for this ID.
  if (Channel *found = GetChannel(object_id); found != nullptr) {
    // The answer comes straight from the object header's error check.
    return found->mutable_object->header->CheckHasError();
  }
  return Status::NotFound(
      absl::StrFormat("Could not find channel for object ID %s.", object_id.Hex()));
}

#else // defined(__APPLE__) || defined(__linux__)

MutableObjectManager::~MutableObjectManager() {}
Expand Down Expand Up @@ -498,6 +507,10 @@ Status MutableObjectManager::SetErrorAll() {
return Status::NotImplemented("Not supported on Windows.");
}

// Stub compiled when neither __APPLE__ nor __linux__ is defined (the #else branch of
// this file): the channel state is not queried and every call reports NotImplemented.
Status MutableObjectManager::IsChannelClosed(const ObjectID &object_id) {
return Status::NotImplemented("Not supported on Windows.");
}

#endif

} // namespace experimental
Expand Down
23 changes: 18 additions & 5 deletions src/ray/core_worker/experimental_mutable_object_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,15 @@ class MutableObjectManager : public std::enable_shared_from_this<MutableObjectMa
/// Reports whether any channel is registered for an object.
///
/// \param[in] object_id The ID of the object.
/// \return True if a channel lookup for object_id yields a channel, false otherwise.
bool ChannelRegistered(const ObjectID &object_id) {
  return GetChannel(object_id) != nullptr;
}

/// Checks if a reader channel is registered for an object.
///
/// \param[in] object_id The ID of the object.
/// The return status. True if the channel is registered as a reader for object_id,
/// false otherwise.
/// \return The return status. True if the channel is registered as a reader for
/// object_id, false otherwise.
bool ReaderChannelRegistered(const ObjectID &object_id) {
Channel *c = GetChannel(object_id);
if (!c) {
Expand All @@ -122,8 +123,8 @@ class MutableObjectManager : public std::enable_shared_from_this<MutableObjectMa
/// Checks if a writer channel is registered for an object.
///
/// \param[in] object_id The ID of the object.
/// The return status. True if the channel is registered as a writer for object_id,
/// false otherwise.
/// \return The return status. True if the channel is registered as a writer for
/// object_id, false otherwise.
bool WriterChannelRegistered(const ObjectID &object_id) {
Channel *c = GetChannel(object_id);
if (!c) {
Expand Down Expand Up @@ -188,6 +189,18 @@ class MutableObjectManager : public std::enable_shared_from_this<MutableObjectMa
/// an error on acquire.
Status SetErrorAll();

/// Checks if the channel is closed.
///
/// Note that the result is a Status rather than a boolean.
///
/// \param[in] object_id The ID of the object.
/// \return Status::NotFound if no channel exists for object_id. Otherwise, the status
/// produced by the object header's error check, which indicates whether the object has
/// its error bit set (i.e., the channel is closed). On platforms other than macOS and
/// Linux this always returns Status::NotImplemented.
Status IsChannelClosed(const ObjectID &object_id);

/// Returns the channel for object_id. If no channel exists for object_id, returns
/// nullptr.
///
/// \param[in] object_id The ID of the object.
/// \return The channel or nullptr.
Channel *GetChannel(const ObjectID &object_id);

private:
Expand Down
7 changes: 7 additions & 0 deletions src/ray/core_worker/experimental_mutable_object_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ Status MutableObjectProvider::SetError(const ObjectID &object_id) {
return object_manager_->SetError(object_id);
}

Status MutableObjectProvider::GetChannelStatus(const ObjectID &object_id) {
  // Anything that is not a currently registered reader channel is delegated to the
  // object manager, which reports either a missing channel or the closed check result.
  if (!ReaderChannelRegistered(object_id)) {
    return object_manager_->IsChannelClosed(object_id);
  }
  // A registered reader channel is reported as OK.
  return Status::OK();
}

void MutableObjectProvider::PollWriterClosure(
instrumented_io_context &io_context,
const ObjectID &object_id,
Expand Down
12 changes: 12 additions & 0 deletions src/ray/core_worker/experimental_mutable_object_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ class MutableObjectProvider {
/// \param[in] object_id The ID of the object.
Status SetError(const ObjectID &object_id);

/// Returns the current status of the channel for the object. Possible statuses are:
/// 1. Status::OK()
///    - The channel is registered and open.
/// 2. Status::ChannelError()
///    - The channel was registered and previously open, but is now closed.
/// 3. Status::NotFound()
///    - No channel exists for this object.
///
/// \param[in] object_id The ID of the object.
/// \return Current status of the channel.
Status GetChannelStatus(const ObjectID &object_id);

private:
struct LocalReaderInfo {
int64_t num_readers;
Expand Down

0 comments on commit 8a0d633

Please sign in to comment.