Skip to content

Commit

Permalink
[Performance][GPU] Improve _SegmentCopyKernel() (dmlc#3470)
Browse files Browse the repository at this point in the history
* Based on issue dmlc#3436. Improving _SegmentCopyKernel's GPU utilization by switching to nonzero-based thread assignment

* fixing lint issues

* Update cub for cuda 11.5 compatibility (dmlc#3468)

* fixing type mismatch

* tx is guaranteed to be smaller than nnz, hence removing the last check

* minor: updating comment

* adding three unit tests for the CSR slice method to cover some corner cases

Co-authored-by: Abdurrahman Yasar <[email protected]>
Co-authored-by: nv-dlasalle <[email protected]>
Co-authored-by: Jinjing Zhou <[email protected]>
  • Loading branch information
4 people authored Nov 6, 2021
1 parent efe0b06 commit 96cd2ee
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 11 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,5 @@ Contributors
* [@xnuohz](https://github.com/xnuohz)
* [Hao Jin](https://github.com/haojin2) from Amazon
* [Xin Yao](https://github.com/yaox12) from Nvidia
* [Abdurrahman Yasar](https://github.com/ayasar70) from Nvidia
* [Shaked Brody](https://github.com/shakedbr) from Technion
33 changes: 22 additions & 11 deletions src/array/cuda/spmat_op_impl_csr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,27 @@ template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix, int64_t, int64_t);
// Copies the entries of a set of selected rows into a compacted output buffer.
//
// NOTE(review): this listing is a rendered diff hunk without +/- markers, so
// lines from before and after the change appear together (two `row` parameter
// lines, two declarations of `tx`, and an older per-row loop body). The
// comments below point out which lines appear to belong to which version;
// confirm against the actual file before relying on them.
//
// Parameters, as used by the element-wise version of the body:
//   indptr     - read at indptr[u] to get the start position of source row u
//   data       - values to copy; when null, the source position is written instead
//   row        - row[rpos] gives the source row id for output row rpos
//   length     - number of output elements; each thread loops while tx < length
//   n_row      - initial upper bound of the binary search over out_indptr
//   out_indptr - searched to find which output row holds element tx
//   out_data   - destination buffer, written at out_data[tx]
template <typename IdType, typename DType>
__global__ void _SegmentCopyKernel(
const IdType* indptr, const DType* data,
// The next line (with row_stride) appears to be the earlier signature; the
// line after it (with n_row) appears to be its replacement.
const IdType* row, int64_t row_stride, int64_t length,
const IdType* row, int64_t length, int64_t n_row,
const IdType* out_indptr, DType* out_data) {
// Two declarations of the global thread index follow: the `int` form appears
// to be the earlier one, the IdType form (cast applied before the multiply)
// its replacement.
int tx = blockIdx.x * blockDim.x + threadIdx.x;
IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
// Total number of threads in the grid; tx advances by this amount each
// iteration of the loop below.
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
// Earlier per-row body (next five lines): tx indexes a selected row, and the
// thread copies that row's entries indptr[r]..indptr[r + 1] one by one.
// These lines appear to be the removed side of the diff; the closing brace of
// the `for` loop is not shown in this listing.
int rpos = tx * row_stride;
const IdType r = row[rpos];
DType* out_buf = out_data + out_indptr[tx];
for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
*(out_buf++) = data? data[i] : i;
// Element-wise body: tx indexes one output element.
// Find the upper bound for tx in out_indptr using binary search.
// out_indptr is already a prefix sum; n_row = size(out_indptr)-1.
IdType l = 0, r = n_row, m = 0;
while (l < r) {
m = l + (r-l)/2;
if (tx >= out_indptr[m]) {
l = m+1;
} else {
r = m;
}
}

// After the search, l is the first index in [0, n_row) whose out_indptr value
// exceeds tx (or n_row if there is none), so l-1 is the output row that
// contains element tx.
IdType rpos = l-1;
// Offset of element tx within its output row.
IdType rofs = tx - out_indptr[rpos];
// Source row id corresponding to this output row.
const IdType u = row[rpos];
// Copy the matching source entry, or write its source position when data is
// null.
out_data[tx] = data? data[indptr[u]+rofs] : indptr[u]+rofs;
tx += stride_x;
}
}
Expand All @@ -250,21 +260,22 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
IdArray ret_indptr = aten::CumSum(aten::CSRGetRowNNZ(csr, rows), true);
const int64_t nnz = aten::IndexSelect<IdType>(ret_indptr, len);

const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
const int nt = 256; // for better GPU usage of small invocations
const int nb = (nnz + nt - 1) / nt;

// Copy indices.
IdArray ret_indices = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
CUDA_KERNEL_CALL(_SegmentCopyKernel,
nb, nt, 0, thr_entry->stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
rows.Ptr<IdType>(), 1, len,
rows.Ptr<IdType>(), nnz, len,
ret_indptr.Ptr<IdType>(), ret_indices.Ptr<IdType>());
// Copy data.
IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
CUDA_KERNEL_CALL(_SegmentCopyKernel,
nb, nt, 0, thr_entry->stream,
csr.indptr.Ptr<IdType>(), CSRHasData(csr)? csr.data.Ptr<IdType>() : nullptr,
rows.Ptr<IdType>(), 1, len,
rows.Ptr<IdType>(), nnz, len,
ret_indptr.Ptr<IdType>(), ret_data.Ptr<IdType>());
return CSRMatrix(len, csr.num_cols,
ret_indptr, ret_indices, ret_data,
Expand Down
58 changes: 58 additions & 0 deletions tests/cpp/test_spmat_csr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,64 @@ void _TestCSRSliceRows(DLContext ctx) {
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));

// Testing non-increasing row id based slicing
r = aten::VecToIdArray(std::vector<IDX>({3, 2, 1}), sizeof(IDX)*8, ctx);
x = aten::CSRSliceRows(csr, r);
// [[0, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [1, 0, 0, 0, 0]]
// data: [1, 4, 3]
tp = aten::VecToIdArray(std::vector<IDX>({0, 0, 2, 3}), sizeof(IDX)*8, ctx);
ti = aten::VecToIdArray(std::vector<IDX>({2, 3, 0}), sizeof(IDX)*8, ctx);
td = aten::VecToIdArray(std::vector<IDX>({1, 4, 3}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));

// Testing zero-degree row slicing with different rows
r = aten::VecToIdArray(std::vector<IDX>({1, 3, 0, 3, 2}), sizeof(IDX)*8, ctx);
x = aten::CSRSliceRows(csr, r);
// [[1, 0, 0, 0, 0],
// [0, 0, 0, 0, 0],
// [0, 1, 2, 0, 0],
// [0, 0, 0, 0, 0],
// [0, 0, 1, 1, 0]]
// data: [3, 0, 2, 5, 1, 4]
tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 1, 4, 4, 6}), sizeof(IDX)*8, ctx);
ti = aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 2, 2, 3}), sizeof(IDX)*8, ctx);
td = aten::VecToIdArray(std::vector<IDX>({3, 0, 2, 5, 1, 4}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));

// Testing empty output (i.e. sliced rows will be zero-degree)
r = aten::VecToIdArray(std::vector<IDX>({3,3,3}), sizeof(IDX)*8, ctx);
x = aten::CSRSliceRows(csr, r);
// [[0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0]]
// data: []
tp = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 0}), sizeof(IDX)*8, ctx);
ti = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
td = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));

// Testing constant output: we pick last row with at least one nnz
r = aten::VecToIdArray(std::vector<IDX>({2,2,2}), sizeof(IDX)*8, ctx);
x = aten::CSRSliceRows(csr, r);
// [[0, 0, 1, 1, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 1, 1, 0]]
// data: [1, 4, 1, 4, 1, 4]
tp = aten::VecToIdArray(std::vector<IDX>({0, 2, 4, 6}), sizeof(IDX)*8, ctx);
ti = aten::VecToIdArray(std::vector<IDX>({2, 3, 2, 3, 2, 3}), sizeof(IDX)*8, ctx);
td = aten::VecToIdArray(std::vector<IDX>({1, 4, 1, 4, 1, 4}), sizeof(IDX)*8, ctx);
ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
}

TEST(SpmatTest, TestCSRSliceRows) {
Expand Down

0 comments on commit 96cd2ee

Please sign in to comment.