Skip to content

Commit

Permalink
IVF-PQ: Fix illegal memory access with large max_samples (rapidsai#1685)
Browse files Browse the repository at this point in the history
PR rapidsai#1356 increased the maximum internal batch size of IVF-PQ and, by doing so, exposed a `uint32_t` integer overflow bug that resulted in incorrectly allocated intermediate buffers.
This PR fixes the original bug, but also:
  - Hardens the places where `max_samples` multiplied by another value could cause integer overflow
  - Makes a more careful estimate of the workspace size to avoid `out_of_memory` errors from the limiting resource adaptor
  - Removes an unused argument (`n_rows`) from `compute_similarity_kernel` that has been carried along through code updates for a really long time

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1685
  • Loading branch information
achirkin authored Jul 27, 2023
1 parent a20f497 commit 4574e9a
Show file tree
Hide file tree
Showing 11 changed files with 25 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ template <typename OutT,
int Capacity,
bool PrecompBaseDiff,
bool EnableSMemLut>
__global__ void compute_similarity_kernel(uint32_t n_rows,
uint32_t dim,
__global__ void compute_similarity_kernel(uint32_t dim,
uint32_t n_probes,
uint32_t pq_dim,
uint32_t n_queries,
Expand Down Expand Up @@ -82,7 +81,6 @@ struct selected {
template <typename OutT, typename LutT, typename IvfSampleFilterT>
void compute_similarity_run(selected<OutT, LutT, IvfSampleFilterT> s,
rmm::cuda_stream_view stream,
uint32_t n_rows,
uint32_t dim,
uint32_t n_probes,
uint32_t pq_dim,
Expand Down Expand Up @@ -156,7 +154,6 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim,
* Setting this to `false` allows to reduce the shared memory usage (and maximum data dim)
* at the cost of reducing global memory reading throughput.
*
* @param n_rows the number of records in the dataset
* @param dim the dimensionality of the data (NB: after rotation transform, i.e. `index.rot_dim()`).
* @param n_probes the number of clusters to search for each query
* @param pq_dim
Expand Down Expand Up @@ -251,8 +250,7 @@ template <typename OutT,
int Capacity,
bool PrecompBaseDiff,
bool EnableSMemLut>
__global__ void compute_similarity_kernel(uint32_t n_rows,
uint32_t dim,
__global__ void compute_similarity_kernel(uint32_t dim,
uint32_t n_probes,
uint32_t pq_dim,
uint32_t n_queries,
Expand Down Expand Up @@ -327,14 +325,15 @@ __global__ void compute_similarity_kernel(uint32_t n_rows,
uint32_t* out_indices = nullptr;
if constexpr (kManageLocalTopK) {
// Store topk calculated distances to out_scores (and its indices to out_indices)
out_scores = _out_scores + topk * (probe_ix + (n_probes * query_ix));
out_indices = _out_indices + topk * (probe_ix + (n_probes * query_ix));
const uint64_t out_offset = probe_ix + n_probes * query_ix;
out_scores = _out_scores + out_offset * topk;
out_indices = _out_indices + out_offset * topk;
} else {
// Store all calculated distances to out_scores
out_scores = _out_scores + max_samples * query_ix;
out_scores = _out_scores + uint64_t(max_samples) * query_ix;
}
uint32_t label = cluster_labels[n_probes * query_ix + probe_ix];
const float* cluster_center = cluster_centers + (dim * label);
const float* cluster_center = cluster_centers + dim * label;
const float* pq_center;
if (codebook_kind == codebook_gen::PER_SUBSPACE) {
pq_center = pq_centers;
Expand Down Expand Up @@ -602,7 +601,6 @@ template <typename OutT,
typename IvfSampleFilterT = raft::neighbors::filtering::none_ivf_sample_filter>
void compute_similarity_run(selected<OutT, LutT, IvfSampleFilterT> s,
rmm::cuda_stream_view stream,
uint32_t n_rows,
uint32_t dim,
uint32_t n_probes,
uint32_t pq_dim,
Expand All @@ -625,8 +623,7 @@ void compute_similarity_run(selected<OutT, LutT, IvfSampleFilterT> s,
OutT* _out_scores,
uint32_t* _out_indices)
{
s.kernel<<<s.grid_dim, s.block_dim, s.smem_size, stream>>>(n_rows,
dim,
s.kernel<<<s.grid_dim, s.block_dim, s.smem_size, stream>>>(dim,
n_probes,
pq_dim,
n_queries,
Expand Down
29 changes: 17 additions & 12 deletions cpp/include/raft/neighbors/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,10 @@ void ivfpq_search_worker(raft::resources const& handle,
auto stream = resource::get_cuda_stream(handle);
auto mr = resource::get_workspace_resource(handle);

bool manage_local_topk = is_local_topk_feasible(topK, n_probes, n_queries);
auto topk_len = manage_local_topk ? n_probes * topK : max_samples;
bool manage_local_topk = is_local_topk_feasible(topK, n_probes, n_queries);
auto topk_len = manage_local_topk ? n_probes * topK : max_samples;
std::size_t n_queries_probes = std::size_t(n_queries) * std::size_t(n_probes);
std::size_t n_queries_topk_len = std::size_t(n_queries) * std::size_t(topk_len);
if (manage_local_topk) {
RAFT_LOG_DEBUG("Fused version of the search kernel is selected (manage_local_topk == true)");
} else {
Expand All @@ -448,13 +450,13 @@ void ivfpq_search_worker(raft::resources const& handle,
rmm::device_uvector<uint32_t> index_list_sorted_buf(0, stream, mr);
uint32_t* index_list_sorted = nullptr;
rmm::device_uvector<uint32_t> num_samples(n_queries, stream, mr);
rmm::device_uvector<uint32_t> chunk_index(n_queries * n_probes, stream, mr);
rmm::device_uvector<uint32_t> chunk_index(n_queries_probes, stream, mr);
// [maxBatchSize, max_samples] or [maxBatchSize, n_probes, topk]
rmm::device_uvector<ScoreT> distances_buf(n_queries * topk_len, stream, mr);
rmm::device_uvector<ScoreT> distances_buf(n_queries_topk_len, stream, mr);
rmm::device_uvector<uint32_t> neighbors_buf(0, stream, mr);
uint32_t* neighbors_ptr = nullptr;
if (manage_local_topk) {
neighbors_buf.resize(n_queries * topk_len, stream);
neighbors_buf.resize(n_queries_topk_len, stream);
neighbors_ptr = neighbors_buf.data();
}
rmm::device_uvector<uint32_t> neighbors_uint32_buf(0, stream, mr);
Expand All @@ -479,10 +481,10 @@ void ivfpq_search_worker(raft::resources const& handle,
// The goal is to increase the L2 cache hit rate when reading the vectors
// of a cluster by processing the cluster at the same time as much as
// possible.
index_list_sorted_buf.resize(n_queries * n_probes, stream);
index_list_sorted_buf.resize(n_queries_probes, stream);
auto index_list_buf =
make_device_mdarray<uint32_t>(handle, mr, make_extents<uint32_t>(n_queries * n_probes));
rmm::device_uvector<uint32_t> cluster_labels_out(n_queries * n_probes, stream, mr);
make_device_mdarray<uint32_t>(handle, mr, make_extents<uint32_t>(n_queries_probes));
rmm::device_uvector<uint32_t> cluster_labels_out(n_queries_probes, stream, mr);
auto index_list = index_list_buf.data_handle();
index_list_sorted = index_list_sorted_buf.data();

Expand All @@ -497,7 +499,7 @@ void ivfpq_search_worker(raft::resources const& handle,
cluster_labels_out.data(),
index_list,
index_list_sorted,
n_queries * n_probes,
n_queries_probes,
begin_bit,
end_bit,
stream);
Expand All @@ -508,7 +510,7 @@ void ivfpq_search_worker(raft::resources const& handle,
cluster_labels_out.data(),
index_list,
index_list_sorted,
n_queries * n_probes,
n_queries_probes,
begin_bit,
end_bit,
stream);
Expand Down Expand Up @@ -558,7 +560,6 @@ void ivfpq_search_worker(raft::resources const& handle,
}
compute_similarity_run(search_instance,
stream,
index.size(),
index.rot_dim(),
n_probes,
index.pq_dim(),
Expand Down Expand Up @@ -706,7 +707,11 @@ inline auto get_max_batch_size(raft::resources const& res,
}
// Check if the tmp distance buffer is not too big
auto ws_size = [k, n_probes, max_samples](uint32_t bs) -> uint64_t {
return uint64_t(is_local_topk_feasible(k, n_probes, bs) ? k * n_probes : max_samples) * bs;
const uint64_t buffers_fused = 12ull * k * n_probes;
const uint64_t buffers_non_fused = 4ull * max_samples;
const uint64_t other = 32ull * n_probes;
return static_cast<uint64_t>(bs) *
(other + (is_local_topk_feasible(k, n_probes, bs) ? buffers_fused : buffers_non_fused));
};
auto max_ws_size = resource::get_workspace_free_bytes(res);
if (ws_size(max_batch_size) > max_ws_size) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
template void raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \\
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \\
rmm::cuda_stream_view stream, \\
uint32_t n_rows, \\
uint32_t dim, \\
uint32_t n_probes, \\
uint32_t pq_dim, \\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, IvfSampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, IvfSampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
Expand Down

0 comments on commit 4574e9a

Please sign in to comment.