Skip to content

Commit

Permalink
Merge pull request rapidsai#1992 from rapidsai/branch-23.12
Browse files Browse the repository at this point in the history
Forward-merge branch-23.12 to branch-24.02
  • Loading branch information
GPUtester authored Nov 14, 2023
2 parents 06979df + 77bc461 commit e875e04
Show file tree
Hide file tree
Showing 5 changed files with 337 additions and 397 deletions.
84 changes: 22 additions & 62 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace raft::matrix::detail {

// this is a subset of algorithms, chosen by running the algorithm_selection
// notebook in cpp/scripts/heuristics/select_k
enum class Algo { kRadix11bits, kWarpDistributedShm, kFaissBlockSelect };
enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bitsExtraPass };

/**
* Predict the fastest select_k algorithm based on the number of rows/cols/k
Expand All @@ -50,73 +50,29 @@ enum class Algo { kRadix11bits, kWarpDistributedShm, kFaissBlockSelect };
*/
inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
{
if (k > 134) {
if (k > 256) {
if (k > 809) {
return Algo::kRadix11bits;
} else {
if (rows > 124) {
if (cols > 63488) {
return Algo::kFaissBlockSelect;
} else {
return Algo::kRadix11bits;
}
} else {
return Algo::kRadix11bits;
}
}
} else {
if (cols > 678736) {
return Algo::kWarpDistributedShm;
if (k > 256) {
if (cols > 16862) {
if (rows > 1020) {
return Algo::kRadix11bitsExtraPass;
} else {
return Algo::kRadix11bits;
}
} else {
return Algo::kRadix11bitsExtraPass;
}
} else {
if (cols > 13776) {
if (rows > 335) {
if (k > 1) {
if (rows > 546) {
return Algo::kWarpDistributedShm;
} else {
if (k > 17) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kFaissBlockSelect;
}
}
} else {
return Algo::kFaissBlockSelect;
}
if (k > 2) {
if (cols > 22061) {
return Algo::kWarpDistributedShm;
} else {
if (k > 44) {
if (cols > 1031051) {
return Algo::kWarpDistributedShm;
} else {
if (rows > 22) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kRadix11bits;
}
}
} else {
return Algo::kWarpDistributedShm;
}
}
} else {
if (k > 1) {
if (rows > 188) {
if (rows > 198) {
return Algo::kWarpDistributedShm;
} else {
if (k > 72) {
return Algo::kRadix11bits;
} else {
return Algo::kWarpDistributedShm;
}
return Algo::kWarpImmediate;
}
} else {
return Algo::kFaissBlockSelect;
}
} else {
return Algo::kWarpImmediate;
}
}
}
Expand Down Expand Up @@ -294,6 +250,8 @@ void select_k(raft::resources const& handle,

switch (algo) {
case Algo::kRadix11bits:
case Algo::kRadix11bitsExtraPass: {
bool fused_last_filter = algo == Algo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
Expand All @@ -302,7 +260,7 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
true, // fused_last_filter
fused_last_filter,
stream,
mr);

Expand All @@ -324,13 +282,15 @@ void select_k(raft::resources const& handle,
handle, raft::make_const_mdspan(offsets.view()), keys, vals, select_min);
}
return;
}
case Algo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream);
case Algo::kWarpImmediate:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_immediate>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
default: RAFT_FAIL("K-selection Algorithm not supported.");
}
}
Expand Down
Loading

0 comments on commit e875e04

Please sign in to comment.