Skip to content

Commit

Permalink
[Core][Streaming Generator] Fix memory leak from the end of object st…
Browse files Browse the repository at this point in the history
…ream object (ray-project#38152)


---------

Signed-off-by: Edward Oakes <[email protected]>
Co-authored-by: Edward Oakes <[email protected]>
  • Loading branch information
rkooo567 and edoakes authored Aug 8, 2023
1 parent 75a700a commit 373f5f1
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 15 deletions.
3 changes: 3 additions & 0 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3238,6 +3238,9 @@ cdef class CoreWorker:
logger.warning("Local object store memory usage:\n{}\n".format(
message.decode("utf-8")))

def get_memory_store_size(self):
return CCoreWorkerProcess.GetCoreWorker().GetMemoryStoreSize()

cdef python_label_match_expressions_to_c(
self, python_expressions,
CLabelMatchExpressions *c_expressions):
Expand Down
1 change: 1 addition & 0 deletions python/ray/includes/libcoreworker.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
int64_t item_index,
uint64_t attempt_number)
c_string MemoryUsageString()
int GetMemoryStoreSize()

CWorkerContext &GetWorkerContext()
void YieldCurrentFiber(CFiberEvent &coroutine_done)
Expand Down
36 changes: 36 additions & 0 deletions python/ray/tests/test_streaming_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def assert_no_leak():
for rc in ref_counts.values():
assert rc["local"] == 0
assert rc["submitted"] == 0
assert core_worker.get_memory_store_size() == 0


class MockedWorker:
Expand Down Expand Up @@ -1132,6 +1133,41 @@ async def main():
assert 4.5 < time.time() - s < 6.5


def test_no_memory_store_obj_leak(shutdown_only):
"""Fixes https://github.com/ray-project/ray/issues/38089
Verify there's no leak from in-memory object store when
using a streaming generator.
"""
ray.init()

@ray.remote
def f():
for _ in range(10):
yield 1

for _ in range(10):
for ref in f.options(num_returns="streaming").remote():
del ref

time.sleep(0.2)

core_worker = ray._private.worker.global_worker.core_worker
assert core_worker.get_memory_store_size() == 0
assert_no_leak()

for _ in range(10):
for ref in f.options(num_returns="streaming").remote():
break

time.sleep(0.2)

del ref
core_worker = ray._private.worker.global_worker.core_worker
assert core_worker.get_memory_store_size() == 0
assert_no_leak()


if __name__ == "__main__":
import os

Expand Down
2 changes: 2 additions & 0 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
}
}

int GetMemoryStoreSize() { return memory_store_->Size(); }

/// Returns a map of all ObjectIDs currently in scope with a pair of their
/// (local, submitted_task) reference counts. For debugging purposes.
std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const;
Expand Down
1 change: 1 addition & 0 deletions src/ray/core_worker/reference_count.cc
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ void ReferenceCounter::DeleteReferenceInternal(ReferenceTable::iterator it,
it->second.on_ref_removed(id);
it->second.on_ref_removed = nullptr;
}

PRINT_REF_COUNT(it);

// Whether it is safe to unpin the value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
absl::flat_hash_set<ObjectID> *plasma_ids_to_delete) {
absl::MutexLock lock(&mu_);
for (const auto &object_id : object_ids) {
RAY_LOG(DEBUG) << "Delete an object from a memory store. ObjectId: " << object_id;
auto it = objects_.find(object_id);
if (it != objects_.end()) {
if (it->second->IsInPlasmaError()) {
Expand All @@ -492,6 +493,7 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
absl::MutexLock lock(&mu_);
for (const auto &object_id : object_ids) {
RAY_LOG(DEBUG) << "Delete an object from a memory store. ObjectId: " << object_id;
auto it = objects_.find(object_id);
if (it != objects_.end()) {
OnDelete(it->second);
Expand Down
22 changes: 14 additions & 8 deletions src/ray/core_worker/task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,29 @@ const int64_t kTaskFailureThrottlingThreshold = 50;
// Throttle task failure logs to once this interval.
const int64_t kTaskFailureLoggingFrequencyMillis = 5000;

std::vector<ObjectID> ObjectRefStream::GetItemsUnconsumed() const {
std::vector<ObjectID> result;
absl::flat_hash_set<ObjectID> ObjectRefStream::GetItemsUnconsumed() const {
absl::flat_hash_set<ObjectID> result;
for (int64_t index = 0; index <= max_index_seen_; index++) {
const auto &object_id = GetObjectRefAtIndex(index);
if (refs_written_to_stream_.find(object_id) == refs_written_to_stream_.end()) {
continue;
}

if (index >= next_index_) {
result.push_back(object_id);
result.emplace(object_id);
}
}

if (end_of_stream_index_ != -1) {
// End of stream index is never consumed by a caller
// so we should add it here.
result.push_back(GetObjectRefAtIndex(end_of_stream_index_));
const auto &object_id = GetObjectRefAtIndex(end_of_stream_index_);
result.emplace(object_id);
}

// Temporarily owned refs are not consumed.
for (const auto &object_id : temporarily_owned_refs_) {
result.push_back(object_id);
result.emplace(object_id);
}
return result;
}
Expand Down Expand Up @@ -428,7 +429,7 @@ bool TaskManager::HandleTaskReturn(const ObjectID &object_id,

void TaskManager::DelObjectRefStream(const ObjectID &generator_id) {
RAY_LOG(DEBUG) << "Deleting an object ref stream of an id " << generator_id;
std::vector<ObjectID> object_ids_unconsumed;
absl::flat_hash_set<ObjectID> object_ids_unconsumed;

{
absl::MutexLock lock(&mu_);
Expand All @@ -441,12 +442,17 @@ void TaskManager::DelObjectRefStream(const ObjectID &generator_id) {
object_ids_unconsumed = stream.GetItemsUnconsumed();
object_ref_streams_.erase(generator_id);
}

// When calling RemoveLocalReference, we shouldn't hold a lock.
for (const auto &object_id : object_ids_unconsumed) {
std::vector<ObjectID> deleted;
RAY_LOG(INFO) << "Removing unconsume streaming ref " << object_id;
RAY_LOG(DEBUG) << "Removing unconsume streaming ref " << object_id;
reference_counter_->RemoveLocalReference(object_id, &deleted);
// TODO(sang): This is required because the reference counter
// cannot remove objects from the in memory store.
// Instead of doing this manually here, we should modify
// reference_count.h to automatically remove objects
// when the ref goes to 0.
in_memory_store_->Delete(deleted);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/task_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class ObjectRefStream {
/// Get all the ObjectIDs that are not read yet via TryReadNextItem.
///
/// \return A list of object IDs that are not read yet.
std::vector<ObjectID> GetItemsUnconsumed() const;
absl::flat_hash_set<ObjectID> GetItemsUnconsumed() const;

private:
ObjectID GetObjectRefAtIndex(int64_t generator_index) const;
Expand Down
15 changes: 9 additions & 6 deletions src/ray/core_worker/test/task_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) {

TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) {
/**
* Verify DEL cleans all references and ignore all future WRITE.
* Verify DEL cleans all references/objects and ignore all future WRITE.
*
* CREATE WRITE WRITE DEL (make sure no refs are leaked)
*/
Expand Down Expand Up @@ -1602,6 +1602,8 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) {

// NumObjectIDsInScope == Generator + 2 WRITE
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
// 2 in memory objects.
ASSERT_EQ(store_->Size(), 2);
std::vector<std::shared_ptr<RayObject>> results;
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results));
Expand All @@ -1614,11 +1616,8 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) {
// DELETE. This should clean all references except generator id.
manager_.DelObjectRefStream(generator_id);
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
// Unfortunately, when the obj ref goes out of scope,
// this is called from the language frontend. We mimic this behavior
// by manually calling these APIs.
store_->Delete({dynamic_return_id});
store_->Delete({dynamic_return_id2});
// All the in memory objects should be cleaned up.
ASSERT_EQ(store_->Size(), 0);
ASSERT_TRUE(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results).IsTimedOut());
results.clear();
ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results).IsTimedOut());
Expand All @@ -1640,6 +1639,8 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) {
ASSERT_FALSE(manager_.HandleReportGeneratorItemReturns(req));
// The write should have been no op. No refs and no obj values except the generator id.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
// All the in memory objects should be cleaned up.
ASSERT_EQ(store_->Size(), 0);
ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, false, &results).IsTimedOut());
results.clear();

Expand Down Expand Up @@ -1741,6 +1742,8 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelOutOfOrder) {

// There must be only a generator ID.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
// All the objects should be cleaned up.
ASSERT_EQ(store_->Size(), 0);
CompletePendingStreamingTask(spec, caller_address, 0);
}

Expand Down

0 comments on commit 373f5f1

Please sign in to comment.