Skip to content

Commit

Permalink
[core] fix depth of threaded task (ray-project#30902)
Browse files Browse the repository at this point in the history
Signed-off-by: Clarence Ng <[email protected]>

AIR runs training on a separate RunnerThread that is implemented with threading.Thread.

The current task spec and depth is thread-local, which breaks when a different thread such as the RunnerThread is the one that is starting the task. The RunnerThread starts the training worker and dataset. Without this fix, the depth is incorrectly set to 1. With the fix it is correctly set to 2.

Program this fixes: https://gist.github.com/clarng/50094104476080d1b2d04bded4e370b3

Going forward we should re-evaluate whether the worker context should remain thread-local, or whether we should clean this up and replace it with a thread-safe mechanism.
  • Loading branch information
clarng authored Dec 7, 2022
1 parent 7728cc7 commit 8b62748
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
29 changes: 29 additions & 0 deletions python/ray/tests/test_nested_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import threading

import pytest

Expand Down Expand Up @@ -150,5 +151,33 @@ def run(self):
ray.get(nested_actor.run.remote())


def test_thread_create_task(shutdown_only):
    """Verify task depth is tracked correctly for tasks submitted off-thread.

    A remote task (depth 1) spawns a plain ``threading.Thread`` which submits a
    nested remote task. Before the fix, depth was read from thread-local state,
    so the nested task incorrectly got depth 1; it must be 2 (parent depth + 1).
    """

    @ray.remote
    def thread_create_task():
        # This task is submitted from the driver, so its own depth is 1.
        assert ray._private.worker.global_worker.task_depth == 1

        @ray.remote
        def check_nested_depth():
            # Submitted from a non-main thread of a depth-1 task; the depth
            # must still be parent + 1 = 2.
            assert ray._private.worker.global_worker.task_depth == 2

        # Collect failures in a closure-local list instead of a module-level
        # global: exceptions raised inside a Thread's target do not propagate
        # to join(), and a list avoids mutating shared module state.
        errors = []

        def run_check_nested_depth():
            try:
                # max_retries=0 so a failed assertion surfaces immediately
                # instead of being retried.
                ray.get(check_nested_depth.options(max_retries=0).remote())
            except Exception as e:
                errors.append(e)

        t1 = threading.Thread(target=run_check_nested_depth)
        t1.start()
        t1.join()

        assert not errors, f"nested task failed: {errors}"

    ray.get(thread_create_task.remote())


# Allow running this test module directly (outside the CI harness):
# delegates to pytest and propagates its exit code to the shell.
if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", __file__]))
10 changes: 5 additions & 5 deletions src/ray/core_worker/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,8 @@ ObjectIDIndexType WorkerContext::GetNextPutIndex() {
}

int64_t WorkerContext::GetTaskDepth() const {
  // Return the depth of the task this worker is currently executing.
  //
  // The depth is stored in a mutex-guarded member rather than derived from the
  // thread-local current task spec: tasks may be submitted from auxiliary
  // threads (e.g. a Python threading.Thread started inside a task), and those
  // threads have no thread-local task context, which previously made the
  // nested depth come out wrong (see ray-project#30902).
  absl::ReaderMutexLock lock(&mutex_);
  return task_depth_;
}

const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id_; }
Expand Down Expand Up @@ -239,9 +236,12 @@ void WorkerContext::SetCurrentActorId(const ActorID &actor_id) LOCKS_EXCLUDED(mu
current_actor_id_ = actor_id;
}

void WorkerContext::SetTaskDepth(int64_t depth) { task_depth_ = depth; }

void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
absl::WriterMutexLock lock(&mutex_);
GetThreadContext().SetCurrentTask(task_spec);
SetTaskDepth(task_spec.GetDepth());
RAY_CHECK(current_job_id_ == task_spec.JobId());
if (task_spec.IsNormalTask()) {
current_task_is_direct_call_ = true;
Expand Down
3 changes: 3 additions & 0 deletions src/ray/core_worker/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class WorkerContext {

void SetCurrentActorId(const ActorID &actor_id) LOCKS_EXCLUDED(mutex_);

void SetTaskDepth(int64_t depth) EXCLUSIVE_LOCKS_REQUIRED(mutex_);

void SetCurrentTask(const TaskSpecification &task_spec) LOCKS_EXCLUDED(mutex_);

void ResetCurrentTask();
Expand Down Expand Up @@ -103,6 +105,7 @@ class WorkerContext {
const WorkerType worker_type_;
const WorkerID worker_id_;
const JobID current_job_id_;
int64_t task_depth_ GUARDED_BY(mutex_) = 0;
ActorID current_actor_id_ GUARDED_BY(mutex_);
int current_actor_max_concurrency_ GUARDED_BY(mutex_) = 1;
bool current_actor_is_asyncio_ GUARDED_BY(mutex_) = false;
Expand Down

0 comments on commit 8b62748

Please sign in to comment.