Skip to content

Commit

Permalink
Better work per thread calculation.
Browse files Browse the repository at this point in the history
  • Loading branch information
novoselrok committed Jul 28, 2020
1 parent 1297412 commit 4613202
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions src/annoylib.h
Original file line number Diff line number Diff line change
Expand Up @@ -972,19 +972,12 @@ template<typename S, typename T, typename Distance, typename Random>

_n_nodes = _n_items;

std::mutex _nodes_mutex;
std::mutex nodes_mutex;
ThreadBarrier barrier(n_threads);
vector<std::thread> threads(n_threads);
int work_per_thread = (int)floor(q / (double)n_threads);
int work_remainder = q % n_threads;
for (int i = 0; i < n_threads; i++) {
int trees_per_thread = -1;
if (q > -1) {
// First thread picks up the remainder of the work
trees_per_thread = i == 0 ? work_per_thread + work_remainder : work_per_thread;
}

threads[i] = std::thread(&AnnoyIndex<S, T, D, Random>::_thread_build, this, trees_per_thread, i, std::ref(barrier), std::ref(_nodes_mutex));
int trees_per_thread = q == -1 ? -1 : (int)floor((q + i) / n_threads);
threads[i] = std::thread(&AnnoyIndex<S, T, D, Random>::_thread_build, this, trees_per_thread, i, std::ref(barrier), std::ref(nodes_mutex));
}

for (auto& thread : threads) {
Expand Down Expand Up @@ -1206,7 +1199,7 @@ template<typename S, typename T, typename Distance, typename Random>
return get_node_ptr<S, Node>(_nodes, _s, i);
}

void _thread_build(int q, int thread_idx, ThreadBarrier& barrier, std::mutex& _nodes_mutex) {
void _thread_build(int q, int thread_idx, ThreadBarrier& barrier, std::mutex& nodes_mutex) {
Random _random;
// Each thread needs its own seed, otherwise each thread would be building the same tree(s)
int seed = _is_seeded ? _seed + thread_idx : thread_idx;
Expand Down Expand Up @@ -1250,7 +1243,7 @@ template<typename S, typename T, typename Distance, typename Random>
// Wait for all threads to finish before we can start inserting tree nodes into global _nodes array
barrier.wait();

_nodes_mutex.lock();
nodes_mutex.lock();
// When a thread wants to insert local tree nodes into global _nodes it has to stop pretending that there is
// going to be only one tree. Each thread has to update all split nodes children that are pointing to other split nodes
// because their indices will change once inserted into global _nodes.
Expand Down Expand Up @@ -1283,7 +1276,7 @@ template<typename S, typename T, typename Distance, typename Random>
thread_roots[tree_idx] += split_nodes_offset;
}
_roots.insert(_roots.end(), thread_roots.begin(), thread_roots.end());
_nodes_mutex.unlock();
nodes_mutex.unlock();
}

S _make_tree(const vector<S >& indices, vector<Node* >& split_nodes, bool is_root, Random& _random) {
Expand Down

0 comments on commit 4613202

Please sign in to comment.