Skip to content

Commit

Permalink
Fix syncing mechanism in raft-ann-bench C++ search (rapidsai#1961)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1961
  • Loading branch information
divyegala authored Nov 5, 2023
1 parent aa3e229 commit bafd2a8
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <benchmark/benchmark.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <condition_variable>
Expand All @@ -39,6 +40,7 @@ namespace raft::bench::ann {

std::mutex init_mutex;
std::condition_variable cond_var;
std::atomic_int processed_threads{0};

static inline std::unique_ptr<AnnBase> current_algo{nullptr};
static inline std::shared_ptr<AlgoProperty> current_algo_props{nullptr};
Expand Down Expand Up @@ -199,6 +201,7 @@ void bench_search(::benchmark::State& state,
*/
if (state.thread_index() == 0) {
std::unique_lock lk(init_mutex);
cond_var.wait(lk, [] { return processed_threads.load(std::memory_order_acquire) == 0; });
// algo is static to cache it between close search runs to save time on index loading
static std::string index_file = "";
if (index.file != index_file) {
Expand Down Expand Up @@ -247,13 +250,14 @@ void bench_search(::benchmark::State& state,
}

query_set = dataset->query_set(current_algo_props->query_memory_type);
processed_threads.store(state.threads(), std::memory_order_acq_rel);
cond_var.notify_all();
} else {
std::unique_lock lk(init_mutex);
// All other threads will wait for the first thread to initialize the algo.

cond_var.wait(
lk, [] { return current_algo_props.get() != nullptr && current_algo.get() != nullptr; });
cond_var.wait(lk, [&state] {
return processed_threads.load(std::memory_order_acquire) == state.threads();
});
// gbench ensures that all threads are synchronized at the start of the benchmark loop.
// We are accessing shared variables (like current_algo, current_algo_probs) before the
// benchmark loop, therefore the synchronization here is necessary.
Expand Down Expand Up @@ -315,6 +319,10 @@ void bench_search(::benchmark::State& state,

if (state.skipped()) { return; }

// assume thread has finished processing successfully at this point
// last thread to finish processing notifies all
if (processed_threads-- == 0) { cond_var.notify_all(); }

// Use the last thread as a sanity check that all the threads are working.
if (state.thread_index() == state.threads() - 1) {
// evaluate recall
Expand Down

0 comments on commit bafd2a8

Please sign in to comment.