Skip to content

Commit

Permalink
Remove faiss dependency from fused_l2_knn.cuh, selection_faiss.cuh, b…
Browse files Browse the repository at this point in the history
…all_cover.cuh and haversine_distance.cuh (rapidsai#1108)

Remove the dependency on faiss from the fused_l2_knn.cuh, selection_faiss.cuh, ball_cover.cuh and haversine_distance.cuh headers.

This takes a copy of the faiss BlockSelect/WarpSelect device code for top-k selection and updates it to use
raft primitives for things like reductions, KeyValuePair, warp shuffling, etc.

Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Ray Douglass (https://github.com/raydouglass)

URL: rapidsai#1108
  • Loading branch information
benfred authored Jan 5, 2023
1 parent 96578a1 commit 2dd9abb
Showing 16 changed files with 1,216 additions and 287 deletions.
4 changes: 2 additions & 2 deletions ci/checks/copyright.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@
re.compile(r"setup[.]cfg$"),
re.compile(r"meta[.]yaml$")
]
ExemptFiles = ["cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh"]
ExemptFiles = ["cpp/include/raft/spatial/knn/detail/faiss_select/"]

# this will break starting at year 10000, which is probably OK :)
CheckSimple = re.compile(
25 changes: 24 additions & 1 deletion cpp/include/raft/core/kvp.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@

#ifdef _RAFT_HAS_CUDA
#include <cub/cub.cuh>
#include <raft/util/cuda_utils.cuh>
#endif
namespace raft {
/**
@@ -58,5 +59,27 @@ struct KeyValuePair {
{
return (value != b.value) || (key != b.key);
}

/**
 * Strict "less than" ordering between two pairs.
 *
 * The keys are compared first; the values are only consulted as a tie-breaker
 * when the keys compare equal.
 *
 * @param b the pair to compare against
 * @return true when this pair orders before `b`
 */
RAFT_INLINE_FUNCTION bool operator<(const KeyValuePair<_Key, _Value>& b) const
{
  // A smaller key decides the comparison on its own.
  if (key < b.key) { return true; }
  // Keys differ (and ours is not smaller): this pair cannot order first.
  if (!(key == b.key)) { return false; }
  // Equal keys: fall back to the value.
  return value < b.value;
}

/**
 * Strict "greater than" ordering between two pairs.
 *
 * The keys are compared first; the values are only consulted as a tie-breaker
 * when the keys compare equal.
 *
 * @param b the pair to compare against
 * @return true when this pair orders after `b`
 */
RAFT_INLINE_FUNCTION bool operator>(const KeyValuePair<_Key, _Value>& b) const
{
  // A larger key decides the comparison on its own.
  if (key > b.key) { return true; }
  // Keys differ (and ours is not larger): this pair cannot order last.
  if (!(key == b.key)) { return false; }
  // Equal keys: fall back to the value.
  return value > b.value;
}
};

#ifdef _RAFT_HAS_CUDA
/**
 * shfl_xor overload for KeyValuePair.
 *
 * Applies shfl_xor to the key and to the value separately, forwarding
 * `laneMask`, `width` and `mask` unchanged to both calls, and packs the two
 * results into a new KeyValuePair.
 *
 * @param input    pair whose members are exchanged
 * @param laneMask forwarded to the per-member shfl_xor calls
 * @param width    forwarded to the per-member shfl_xor calls (defaults to WarpSize)
 * @param mask     forwarded to the per-member shfl_xor calls (defaults to all bits set)
 * @return a KeyValuePair built from the shuffled key and the shuffled value
 */
template <typename _Key, typename _Value>
RAFT_INLINE_FUNCTION KeyValuePair<_Key, _Value> shfl_xor(const KeyValuePair<_Key, _Value>& input,
                                                         int laneMask,
                                                         int width     = WarpSize,
                                                         uint32_t mask = 0xffffffffu)
{
  const auto shuffled_key   = shfl_xor(input.key, laneMask, width, mask);
  const auto shuffled_value = shfl_xor(input.value, laneMask, width, mask);
  return KeyValuePair<_Key, _Value>(shuffled_key, shuffled_value);
}
#endif
} // end namespace raft
7 changes: 3 additions & 4 deletions cpp/include/raft/spatial/knn/detail/ball_cover.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,7 +21,6 @@
#include "../ball_cover_types.hpp"
#include "ball_cover/common.cuh"
#include "ball_cover/registers.cuh"
#include "block_select_faiss.cuh"
#include "haversine_distance.cuh"
#include "knn_brute_force_faiss.cuh"
#include "selection_faiss.cuh"
@@ -31,15 +30,15 @@

#include <raft/util/cuda_utils.cuh>

#include <raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh>

#include <raft/matrix/matrix.cuh>
#include <raft/random/rng.cuh>
#include <raft/sparse/convert/csr.cuh>

#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <faiss/gpu/utils/Select.cuh>

#include <thrust/fill.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
57 changes: 27 additions & 30 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@
#include "common.cuh"

#include "../../ball_cover_types.hpp"
#include "../block_select_faiss.cuh"
#include "../faiss_select/key_value_block_select.cuh"
#include "../haversine_distance.cuh"
#include "../selection_faiss.cuh"

@@ -28,9 +28,6 @@

#include <raft/util/cuda_utils.cuh>

#include <faiss/gpu/utils/Limits.cuh>
#include <faiss/gpu/utils/Select.cuh>

#include <thrust/fill.h>

namespace raft {
@@ -172,32 +169,32 @@ __global__ void compute_final_dists_registers(const value_t* X_index,
dist_func dfunc,
value_int* dist_counter)
{
static constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize;
static constexpr int kNumWarps = tpb / WarpSize;

__shared__ value_t shared_memK[kNumWarps * warp_q];
__shared__ faiss::gpu::KeyValuePair<value_t, value_idx> shared_memV[kNumWarps * warp_q];
__shared__ KeyValuePair<value_t, value_idx> shared_memV[kNumWarps * warp_q];

const value_t* x_ptr = X + (n_cols * blockIdx.x);
value_t local_x_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_x_ptr[j] = x_ptr[j];
}

faiss::gpu::KeyValueBlockSelect<value_t,
value_idx,
false,
faiss::gpu::Comparator<value_t>,
warp_q,
thread_q,
tpb>
heap(faiss::gpu::Limits<value_t>::getMax(),
faiss::gpu::Limits<value_t>::getMax(),
faiss_select::KeyValueBlockSelect<value_t,
value_idx,
false,
faiss_select::Comparator<value_t>,
warp_q,
thread_q,
tpb>
heap(std::numeric_limits<value_t>::max(),
std::numeric_limits<value_t>::max(),
-1,
shared_memK,
shared_memV,
k);

const value_int n_k = faiss::gpu::utils::roundDown(k, faiss::gpu::kWarpSize);
const value_int n_k = Pow2<WarpSize>::roundDown(k);
value_int i = threadIdx.x;
for (; i < n_k; i += tpb) {
value_idx ind = knn_inds[blockIdx.x * k + i];
@@ -224,7 +221,7 @@ __global__ void compute_final_dists_registers(const value_t* X_index,
// Round R_size down to a multiple of the warp size so that all threads
// of a warp can compute in parallel.

const value_int limit = faiss::gpu::utils::roundDown(R_size, faiss::gpu::kWarpSize);
const value_int limit = Pow2<WarpSize>::roundDown(R_size);

i = threadIdx.x;
for (; i < limit; i += tpb) {
@@ -334,10 +331,10 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index,
distance_func dfunc,
float weight = 1.0)
{
static constexpr value_int kNumWarps = tpb / faiss::gpu::kWarpSize;
static constexpr value_int kNumWarps = tpb / WarpSize;

__shared__ value_t shared_memK[kNumWarps * warp_q];
__shared__ faiss::gpu::KeyValuePair<value_t, value_idx> shared_memV[kNumWarps * warp_q];
__shared__ KeyValuePair<value_t, value_idx> shared_memV[kNumWarps * warp_q];

// TODO: Separate kernels for different widths:
// 1. Very small (between 3 and 32) just use registers for columns of "blockIdx.x"
@@ -352,15 +349,15 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index,
}

// Each warp works on 1 R
faiss::gpu::KeyValueBlockSelect<value_t,
value_idx,
false,
faiss::gpu::Comparator<value_t>,
warp_q,
thread_q,
tpb>
heap(faiss::gpu::Limits<value_t>::getMax(),
faiss::gpu::Limits<value_t>::getMax(),
faiss_select::KeyValueBlockSelect<value_t,
value_idx,
false,
faiss_select::Comparator<value_t>,
warp_q,
thread_q,
tpb>
heap(std::numeric_limits<value_t>::max(),
std::numeric_limits<value_t>::max(),
-1,
shared_memK,
shared_memV,
@@ -390,7 +387,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index,

value_idx R_size = R_stop_offset - R_start_offset;

value_int limit = faiss::gpu::utils::roundDown(R_size, faiss::gpu::kWarpSize);
value_int limit = Pow2<WarpSize>::roundDown(R_size);
value_int i = threadIdx.x;
for (; i < limit; i += tpb) {
// Index and distance of current candidate's nearest landmark
29 changes: 29 additions & 0 deletions cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file thirdparty/LICENSES/LICENSE.faiss
*/

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>

namespace raft::spatial::knn::detail::faiss_select {

/**
 * Device-side ordering predicates for values of type T.
 *
 * The generic version simply defers to T's `<` and `>` operators.
 */
template <typename T>
struct Comparator {
  /** @return true when `lhs < rhs`. */
  static __device__ inline bool lt(T lhs, T rhs)
  {
    return lhs < rhs;
  }

  /** @return true when `lhs > rhs`. */
  static __device__ inline bool gt(T lhs, T rhs)
  {
    return lhs > rhs;
  }
};

/**
 * Specialization of Comparator for `half`.
 *
 * Uses the `__hlt` / `__hgt` half-precision comparison intrinsics rather than
 * the `<` / `>` operators.
 */
template <>
struct Comparator<half> {
  /** @return the result of `__hlt(lhs, rhs)` (lhs less than rhs). */
  static __device__ inline bool lt(half lhs, half rhs)
  {
    return __hlt(lhs, rhs);
  }

  /** @return the result of `__hgt(lhs, rhs)` (lhs greater than rhs). */
  static __device__ inline bool gt(half lhs, half rhs)
  {
    return __hgt(lhs, rhs);
  }
};

} // namespace raft::spatial::knn::detail::faiss_select
Loading

0 comments on commit 2dd9abb

Please sign in to comment.