Call suspend user function on the calling thread (uxlfoundation#727)
Signed-off-by: pavelkumbrasev <[email protected]>
Co-authored-by: Alex <[email protected]>
pavelkumbrasev and alexey-katranov authored Feb 18, 2022
1 parent 6666292 commit e77098d
Showing 9 changed files with 160 additions and 112 deletions.
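
For orientation: this commit changes where the user callback passed to tbb::task::suspend runs. Below is a minimal usage sketch of the public API involved (tbb::task::suspend and tbb::task::resume), assuming an installed oneTBB; the detached resumer thread is illustrative and not part of the diff. After this change the callback executes on the calling thread, before the coroutine switch, so a resume() that arrives early has to be deferred until the stack is actually suspended (the new `notified` state below).

```cpp
#include <oneapi/tbb/parallel_for.h>
#include <oneapi/tbb/task.h>
#include <thread>

int main() {
    oneapi::tbb::parallel_for(0, 4, [](int) {
        oneapi::tbb::task::suspend([](oneapi::tbb::task::suspend_point sp) {
            // After this commit, this callback runs on the suspending thread
            // itself, before the coroutine switch. A resume(sp) that races
            // with the switch is deferred via the `notified` stack state.
            std::thread([sp] { oneapi::tbb::task::resume(sp); }).detach();
        });
    });
    return 0;
}
```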
2 changes: 1 addition & 1 deletion src/tbb/arena.h
@@ -1,5 +1,5 @@
/*
Copyright (c) 2005-2021 Intel Corporation
Copyright (c) 2005-2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
60 changes: 60 additions & 0 deletions src/tbb/scheduler_common.h
@@ -353,6 +353,54 @@ struct suspend_point_type {
bool m_is_critical{ false };
//! Associated coroutine
co_context m_co_context;
//! Suspend point before resume
suspend_point_type* m_prev_suspend_point{nullptr};

// Possible state transitions:
// A -> S -> N -> A
// A -> N -> S -> N -> A
enum class stack_state {
active, // some thread is working with this stack
suspended, // no thread is working with this stack
notified // some thread tried to resume this stack
};

//! The flag required to protect suspend finish and resume call
std::atomic<stack_state> m_stack_state{stack_state::active};

void resume(suspend_point_type* sp) {
__TBB_ASSERT(m_stack_state.load(std::memory_order_relaxed) != stack_state::suspended, "The stack is expected to be active");

sp->m_prev_suspend_point = this;

// Do not access sp after resume
m_co_context.resume(sp->m_co_context);
__TBB_ASSERT(m_stack_state.load(std::memory_order_relaxed) != stack_state::active, nullptr);

finilize_resume();
}

void finilize_resume() {
m_stack_state.store(stack_state::active, std::memory_order_relaxed);
// Set the suspended state for the stack that we left. If the state is already notified, it means that
// someone already tried to resume our previous stack but failed. So, we need to resume it.
// m_prev_suspend_point might be nullptr when destroying co_context based on threads
if (m_prev_suspend_point && m_prev_suspend_point->m_stack_state.exchange(stack_state::suspended) == stack_state::notified) {
r1::resume(m_prev_suspend_point);
}
m_prev_suspend_point = nullptr;
}

bool try_notify_resume() {
// Check that stack is already suspended. Return false if not yet.
return m_stack_state.exchange(stack_state::notified) == stack_state::suspended;
}

void recall_owner() {
__TBB_ASSERT(m_stack_state.load(std::memory_order_relaxed) == stack_state::suspended, nullptr);
m_stack_state.store(stack_state::notified, std::memory_order_relaxed);
m_is_owner_recalled.store(true, std::memory_order_release);
}

struct resume_task final : public d1::task {
task_dispatcher& m_target;
@@ -385,6 +433,15 @@ class alignas (max_nfs_size) task_dispatcher {
friend class delegated_task;
friend struct base_waiter;

//! The list of possible post resume actions.
enum class post_resume_action {
invalid,
register_waiter,
cleanup,
notify,
none
};

//! The data of the current thread attached to this task_dispatcher
thread_data* m_thread_data{ nullptr };

@@ -475,6 +532,9 @@ class alignas (max_nfs_size) task_dispatcher {
#if __TBB_RESUMABLE_TASKS
/* [[noreturn]] */ void co_local_wait_for_all() noexcept;
void suspend(suspend_callback_type suspend_callback, void* user_callback);
void internal_suspend();
void do_post_resume_action();

bool resume(task_dispatcher& target);
suspend_point_type* get_suspend_point();
void init_suspend_point(arena* a, std::size_t stack_size);
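
A self-contained sketch (illustrative, not the library code) of the three-state handshake that suspend_point_type introduces above. The transitions mirror the comment in the diff: the plain path is active -> suspended -> notified -> active, and the racy path is active -> notified -> suspended -> notified -> active, where an early resume attempt is dropped by try_notify_resume() and later replayed from the next stack (spelled finilize_resume in the committed code).

```cpp
#include <atomic>

enum class stack_state { active, suspended, notified };

// Hypothetical stand-in for suspend_point_type: only the state machine,
// no coroutine switching.
struct stack {
    std::atomic<stack_state> state{stack_state::active};

    // Resumer side: returns true only if the target stack is already
    // suspended; otherwise it just leaves the `notified` mark behind.
    bool try_notify_resume() {
        return state.exchange(stack_state::notified) == stack_state::suspended;
    }

    // Resumed side, first thing after the switch: reactivate ourselves and
    // park the stack we left. If that stack was already notified, a resume
    // attempt was dropped earlier and the caller must replay it.
    bool finalize_resume(stack& prev) {
        state.store(stack_state::active, std::memory_order_relaxed);
        return prev.state.exchange(stack_state::suspended) == stack_state::notified;
    }
};

int main() {
    stack a, b;                          // the thread currently runs on `a`
    bool early = a.try_notify_resume();  // resume(a) arrives before `a` suspends
    // ... the thread switches from `a` to `b` here ...
    bool replay = b.finalize_resume(a);  // the dropped resume of `a` is detected
    return (!early && replay) ? 0 : 1;   // expected: early == false, replay == true
}
```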
115 changes: 55 additions & 60 deletions src/tbb/task.cpp
@@ -1,5 +1,5 @@
/*
Copyright (c) 2005-2021 Intel Corporation
Copyright (c) 2005-2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -47,29 +47,29 @@ void suspend(suspend_callback_type suspend_callback, void* user_callback) {
void resume(suspend_point_type* sp) {
assert_pointers_valid(sp, sp->m_arena);
task_dispatcher& task_disp = sp->m_resume_task.m_target;
__TBB_ASSERT(task_disp.m_thread_data == nullptr, nullptr);

// TODO: remove this work-around
// Prolong the arena's lifetime while all coroutines are alive
// (otherwise the arena can be destroyed while some tasks are suspended).
arena& a = *sp->m_arena;
a.my_references += arena::ref_external;

if (task_disp.m_properties.critical_task_allowed) {
// The target is not in the process of executing critical task, so the resume task is not critical.
a.my_resume_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
} else {
#if __TBB_PREVIEW_CRITICAL_TASKS
// The target is in the process of executing critical task, so the resume task is critical.
a.my_critical_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
#endif
if (sp->try_notify_resume()) {
// TODO: remove this work-around
// Prolong the arena's lifetime while all coroutines are alive
// (otherwise the arena can be destroyed while some tasks are suspended).
arena& a = *sp->m_arena;
a.my_references += arena::ref_external;

if (task_disp.m_properties.critical_task_allowed) {
// The target is not in the process of executing critical task, so the resume task is not critical.
a.my_resume_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
} else {
#if __TBB_PREVIEW_CRITICAL_TASKS
// The target is in the process of executing critical task, so the resume task is critical.
a.my_critical_task_stream.push(&sp->m_resume_task, random_lane_selector(sp->m_random));
#endif
}
// Do not access target after that point.
a.advertise_new_work<arena::wakeup>();
// Release our reference to my_arena.
a.on_thread_leaving<arena::ref_external>();
}

// Do not access target after that point.
a.advertise_new_work<arena::wakeup>();

// Release our reference to my_arena.
a.on_thread_leaving<arena::ref_external>();
}

suspend_point_type* current_suspend_point() {
@@ -92,9 +92,7 @@ static task_dispatcher& create_coroutine(thread_data& td) {
return *task_disp;
}

void task_dispatcher::suspend(suspend_callback_type suspend_callback, void* user_callback) {
__TBB_ASSERT(suspend_callback != nullptr, nullptr);
__TBB_ASSERT(user_callback != nullptr, nullptr);
void task_dispatcher::internal_suspend() {
__TBB_ASSERT(m_thread_data != nullptr, nullptr);

arena_slot* slot = m_thread_data->my_arena_slot;
@@ -105,24 +103,31 @@ void task_dispatcher::suspend(suspend_callback_type suspend_callback, void* user
bool is_recalled = default_task_disp.get_suspend_point()->m_is_owner_recalled.load(std::memory_order_acquire);
task_dispatcher& target = is_recalled ? default_task_disp : create_coroutine(*m_thread_data);

thread_data::suspend_callback_wrapper callback = { suspend_callback, user_callback, get_suspend_point() };
m_thread_data->set_post_resume_action(thread_data::post_resume_action::callback, &callback);
resume(target);

if (m_properties.outermost) {
recall_point();
}
}

void task_dispatcher::suspend(suspend_callback_type suspend_callback, void* user_callback) {
__TBB_ASSERT(suspend_callback != nullptr, nullptr);
__TBB_ASSERT(user_callback != nullptr, nullptr);
suspend_callback(user_callback, get_suspend_point());

__TBB_ASSERT(m_thread_data != nullptr, nullptr);
__TBB_ASSERT(m_thread_data->my_post_resume_action == post_resume_action::none, nullptr);
__TBB_ASSERT(m_thread_data->my_post_resume_arg == nullptr, nullptr);
internal_suspend();
}

bool task_dispatcher::resume(task_dispatcher& target) {
// Do not create non-trivial objects on the stack of this function. They might never be destroyed
{
thread_data* td = m_thread_data;
__TBB_ASSERT(&target != this, "We cannot resume to ourself");
__TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
__TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
__TBB_ASSERT(td->my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
__TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");

// Change the task dispatcher
td->detach_task_dispatcher();
@@ -131,13 +136,14 @@ bool task_dispatcher::resume(task_dispatcher& target) {
__TBB_ASSERT(m_suspend_point != nullptr, "Suspend point must be created");
__TBB_ASSERT(target.m_suspend_point != nullptr, "Suspend point must be created");
// Swap to the target coroutine.
m_suspend_point->m_co_context.resume(target.m_suspend_point->m_co_context);

m_suspend_point->resume(target.m_suspend_point);
// Pay attention that m_thread_data can be changed after resume
if (m_thread_data) {
thread_data* td = m_thread_data;
__TBB_ASSERT(td != nullptr, "This task dispatcher must be attach to a thread data");
__TBB_ASSERT(td->my_task_dispatcher == this, "Thread data must be attached to this task dispatcher");
td->do_post_resume_action();
do_post_resume_action();

// Remove the recall flag if the thread in its original task dispatcher
arena_slot* slot = td->my_arena_slot;
@@ -151,54 +157,43 @@ bool task_dispatcher::resume(task_dispatcher& target) {
return false;
}

void thread_data::do_post_resume_action() {
__TBB_ASSERT(my_post_resume_action != thread_data::post_resume_action::none, "The post resume action must be set");
__TBB_ASSERT(my_post_resume_arg, "The post resume action must have an argument");

switch (my_post_resume_action) {
void task_dispatcher::do_post_resume_action() {
thread_data* td = m_thread_data;
switch (td->my_post_resume_action) {
case post_resume_action::register_waiter:
{
static_cast<market_concurrent_monitor::resume_context*>(my_post_resume_arg)->notify();
break;
}
case post_resume_action::resume:
{
r1::resume(static_cast<suspend_point_type*>(my_post_resume_arg));
break;
}
case post_resume_action::callback:
{
suspend_callback_wrapper callback = *static_cast<suspend_callback_wrapper*>(my_post_resume_arg);
callback();
__TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");
static_cast<market_concurrent_monitor::resume_context*>(td->my_post_resume_arg)->notify();
break;
}
case post_resume_action::cleanup:
{
task_dispatcher* to_cleanup = static_cast<task_dispatcher*>(my_post_resume_arg);
__TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");
task_dispatcher* to_cleanup = static_cast<task_dispatcher*>(td->my_post_resume_arg);
// Release coroutine's reference to my_arena
my_arena->on_thread_leaving<arena::ref_external>();
td->my_arena->on_thread_leaving<arena::ref_external>();
// Cache the coroutine for possible later re-usage
my_arena->my_co_cache.push(to_cleanup);
td->my_arena->my_co_cache.push(to_cleanup);
break;
}
case post_resume_action::notify:
{
suspend_point_type* sp = static_cast<suspend_point_type*>(my_post_resume_arg);
sp->m_is_owner_recalled.store(true, std::memory_order_release);
// Do not access sp because it can be destroyed after the store
__TBB_ASSERT(td->my_post_resume_arg, "The post resume action must have an argument");
suspend_point_type* sp = static_cast<suspend_point_type*>(td->my_post_resume_arg);
sp->recall_owner();
// Do not access sp because it can be destroyed after recall

auto is_our_suspend_point = [sp](market_context ctx) {
return std::uintptr_t(sp) == ctx.my_uniq_addr;
auto is_our_suspend_point = [sp] (market_context ctx) {
return std::uintptr_t(sp) == ctx.my_uniq_addr;
};
my_arena->my_market->get_wait_list().notify(is_our_suspend_point);
td->my_arena->my_market->get_wait_list().notify(is_our_suspend_point);
break;
}
default:
__TBB_ASSERT(false, "Unknown post resume action");
__TBB_ASSERT(td->my_post_resume_action == post_resume_action::none, "Unknown post resume action");
__TBB_ASSERT(td->my_post_resume_arg == nullptr, "The post resume argument should not be set");
}

my_post_resume_action = post_resume_action::none;
my_post_resume_arg = nullptr;
td->clear_post_resume_action();
}

#else
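
The hunks above move do_post_resume_action() from thread_data to task_dispatcher and shrink the action list: the old `callback` action disappears because the user callback now runs on the calling thread inside task_dispatcher::suspend(), and the old `resume` action disappears because a racing resume is handled by the stack-state handshake instead. What remains is a small tagged-action dispatch; here is a generic, self-contained sketch of that pattern (names illustrative, not the library API):

```cpp
#include <cassert>
#include <cstdio>

enum class post_resume_action { invalid, register_waiter, cleanup, notify, none };

struct dispatcher {
    post_resume_action action{post_resume_action::none};
    void* arg{nullptr};

    void set_post_resume_action(post_resume_action a, void* p) {
        assert(action == post_resume_action::none); // one pending action at a time
        action = a;
        arg = p;
    }

    // Runs on the thread that has just switched onto this stack; the action
    // was recorded by whoever initiated the switch away from the other stack.
    void do_post_resume_action() {
        switch (action) {
        case post_resume_action::register_waiter:
            std::printf("publish the waiter so an external thread may resume it\n");
            break;
        case post_resume_action::cleanup:
            std::printf("return the abandoned coroutine to the arena cache\n");
            break;
        case post_resume_action::notify:
            std::printf("mark the suspend point recalled and wake its waiters\n");
            break;
        default:
            break; // post_resume_action::none: nothing was requested
        }
        action = post_resume_action::none;
        arg = nullptr;
    }
};

int main() {
    dispatcher d;
    d.set_post_resume_action(post_resume_action::notify, nullptr);
    d.do_post_resume_action();
}
```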
9 changes: 5 additions & 4 deletions src/tbb/task_dispatcher.cpp
@@ -1,5 +1,5 @@
/*
Copyright (c) 2020-2021 Intel Corporation
Copyright (c) 2020-2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -205,8 +205,9 @@ void task_dispatcher::execute_and_wait(d1::task* t, d1::wait_context& wait_ctx,
// Do not create non-trivial objects on the stack of this function. They will never be destroyed.
assert_pointer_valid(m_thread_data);

m_suspend_point->finilize_resume();
// Basically calls the user callback passed to the tbb::task::suspend function
m_thread_data->do_post_resume_action();
do_post_resume_action();

// Endless loop here because coroutine could be reused
d1::task* resume_task{};
@@ -217,8 +218,8 @@ void task_dispatcher::execute_and_wait(d1::task* t, d1::wait_context& wait_ctx,
assert_task_valid(resume_task);
__TBB_ASSERT(this == m_thread_data->my_task_dispatcher, nullptr);

m_thread_data->set_post_resume_action(thread_data::post_resume_action::cleanup, this);
m_thread_data->set_post_resume_action(post_resume_action::cleanup, this);

} while (resume(static_cast<suspend_point_type::resume_task*>(resume_task)->m_target));
// This code might be unreachable
}
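
The change above fixes the entry sequence of a freshly resumed coroutine: first complete the stack-state handshake, then run the action recorded by the thread that initiated the switch, then loop waiting for resume tasks so the coroutine can be reused. A compressed sketch of that ordering under hypothetical names (the real loop dispatches tasks instead of printing):

```cpp
#include <cstdio>

struct coroutine {
    void finalize_resume()       { std::printf("1. park the previous stack, reactivate ours\n"); }
    void do_post_resume_action() { std::printf("2. run the recorded post-resume action\n"); }
    bool wait_for_resume_task()  { std::printf("3. sleep until resumed again\n"); return false; }

    void co_local_wait_for_all() {
        // The handshake must complete before anything else can observe this
        // stack as active or touch the stack we just left.
        finalize_resume();
        do_post_resume_action();
        while (wait_for_resume_task()) {
            // dispatch the resume task, then go around again (coroutine reuse)
        }
    }
};

int main() { coroutine{}.co_local_wait_for_all(); }
```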
17 changes: 6 additions & 11 deletions src/tbb/task_dispatcher.h
@@ -69,7 +69,7 @@ inline d1::task* suspend_point_type::resume_task::execute(d1::execution_data& ed
// The wait_ctx is present only in external_waiter. In that case we leave the current stack
// in the abandoned state to resume when waiting completes.
thread_data* td = ed_ext.task_disp->m_thread_data;
td->set_post_resume_action(thread_data::post_resume_action::register_waiter, &monitor_node);
td->set_post_resume_action(task_dispatcher::post_resume_action::register_waiter, &monitor_node);

market_concurrent_monitor& wait_list = td->my_arena->my_market->get_wait_list();

@@ -78,11 +78,11 @@
}

td->clear_post_resume_action();
td->set_post_resume_action(thread_data::post_resume_action::resume, ed_ext.task_disp->get_suspend_point());
r1::resume(ed_ext.task_disp->get_suspend_point());
} else {
// If wait_ctx is null, it can be only a worker thread on outermost level because
// coroutine_waiter interrupts bypass loop before the resume_task execution.
ed_ext.task_disp->m_thread_data->set_post_resume_action(thread_data::post_resume_action::notify,
ed_ext.task_disp->m_thread_data->set_post_resume_action(task_dispatcher::post_resume_action::notify,
ed_ext.task_disp->get_suspend_point());
}
// Do not access this task because it might be destroyed
@@ -380,14 +380,9 @@ inline void task_dispatcher::recall_point() {
if (this != &m_thread_data->my_arena_slot->default_task_dispatcher()) {
__TBB_ASSERT(m_suspend_point != nullptr, nullptr);
__TBB_ASSERT(m_suspend_point->m_is_owner_recalled.load(std::memory_order_relaxed) == false, nullptr);
d1::suspend([](suspend_point_type* sp) {
sp->m_is_owner_recalled.store(true, std::memory_order_release);
auto is_related_suspend_point = [sp] (market_context context) {
std::uintptr_t sp_addr = std::uintptr_t(sp);
return sp_addr == context.my_uniq_addr;
};
sp->m_arena->my_market->get_wait_list().notify(is_related_suspend_point);
});

m_thread_data->set_post_resume_action(post_resume_action::notify, get_suspend_point());
internal_suspend();

if (m_thread_data->my_inbox.is_idle_state(true)) {
m_thread_data->my_inbox.set_is_idle(false);
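
Previously, recall_point() went through d1::suspend with an ad-hoc lambda, i.e. through the user-visible suspend path, just to park a borrowed stack. After this change it records the existing `notify` post-resume action and calls internal_suspend() directly, so the recall flag and the market wait-list notification are centralized in do_post_resume_action(). A self-contained sketch of the simplified flow, with hypothetical names and the coroutine switch collapsed into a function call:

```cpp
#include <cstdio>

enum class post_resume_action { none, notify };

struct suspend_point { bool owner_recalled = false; };

struct dispatcher {
    post_resume_action action = post_resume_action::none;
    suspend_point sp;

    void internal_suspend() {
        // Stand-in for the coroutine switch: in the real code the *next*
        // thread to run executes the recorded action on our behalf.
        if (action == post_resume_action::notify) {
            sp.owner_recalled = true;   // recall_owner() in the diff above
            std::printf("wake waiters keyed on this suspend point\n");
        }
        action = post_resume_action::none;
    }

    void recall_point() {
        action = post_resume_action::notify;  // set_post_resume_action(...)
        internal_suspend();
    }
};

int main() { dispatcher{}.recall_point(); }
```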
1 change: 0 additions & 1 deletion src/tbb/task_group_context.cpp
@@ -163,7 +163,6 @@ void task_group_context_impl::bind_to_impl(d1::task_group_context& ctx, thread_d
}

void task_group_context_impl::bind_to(d1::task_group_context& ctx, thread_data* td) {
__TBB_ASSERT(!is_poisoned(ctx.my_context_list), nullptr);
d1::task_group_context::state state = ctx.my_state.load(std::memory_order_acquire);
if (state <= d1::task_group_context::state::locked) {
if (state == d1::task_group_context::state::created &&