Skip to content

Commit

Permalink
Use ALG2 for SpMM in cuSparse (dmlc#2550)
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-dlasalle authored Jan 21, 2021
1 parent 2c6d071 commit 9d90faf
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions src/array/cuda/spmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ void CusparseCsrmm2(
CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
}
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
- // allocate matrix for temporary transposed output
- DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
// all one data array
DType* valptr = nullptr;
if (!A_data) {
Expand All @@ -142,32 +140,35 @@ void CusparseCsrmm2(
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, cuda_dtype));
CUSPARSE_CALL(cusparseCreateDnMat(&matB,
- n, k, n,
- const_cast<DType*>(B_data), cuda_dtype, CUSPARSE_ORDER_COL));
+ k, n, n,
+ const_cast<DType*>(B_data), cuda_dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(cusparseCreateDnMat(&matC,
- m, n, m,
- trans_out, cuda_dtype, CUSPARSE_ORDER_COL));
+ m, n, n,
+ C_data, cuda_dtype, CUSPARSE_ORDER_ROW));

auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
- auto transB = CUSPARSE_OPERATION_TRANSPOSE;
+ auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
size_t workspace_size;
CUSPARSE_CALL(cusparseSpMM_bufferSize(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
- cuda_dtype, CUSPARSE_CSRMM_ALG1,
+ cuda_dtype, CUSPARSE_SPMM_CSR_ALG2,
&workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseSpMM(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
- cuda_dtype, CUSPARSE_CSRMM_ALG1,
+ cuda_dtype, CUSPARSE_SPMM_CSR_ALG2,
workspace));
device->FreeWorkspace(ctx, workspace);

CUSPARSE_CALL(cusparseDestroySpMat(matA));
CUSPARSE_CALL(cusparseDestroyDnMat(matB));
CUSPARSE_CALL(cusparseDestroyDnMat(matC));
#else
+ // allocate matrix for temporary transposed output
+ DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
+
cusparseMatDescr_t descr;
CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
Expand All @@ -182,9 +183,6 @@ void CusparseCsrmm2(
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, trans_out, m));
CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
- #endif
- if (valptr)
- device->FreeWorkspace(ctx, valptr);
// transpose the output matrix
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
Expand All @@ -198,6 +196,9 @@ void CusparseCsrmm2(
&beta, nullptr, n,
C_data, n));
device->FreeWorkspace(ctx, trans_out);
+ #endif
+ if (valptr)
+ device->FreeWorkspace(ctx, valptr);
}
} // namespace cusparse

Expand Down

0 comments on commit 9d90faf

Please sign in to comment.