Skip to content

Commit

Permalink
[CUSPARSE] Update mv! and mm! functions for CuSparseMatrixCOO and CuS…
Browse files Browse the repository at this point in the history
…parseMatrixCSC (JuliaGPU#1592)
  • Loading branch information
amontoison authored Oct 5, 2022
1 parent 49791e8 commit c5072c0
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 161 deletions.
175 changes: 62 additions & 113 deletions lib/cusparse/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ mutable struct CuSparseMatrixDescriptor
return obj
end

function CuSparseMatrixDescriptor(A::CuSparseMatrixCSC, IndexBase::Char; convert=true)
function CuSparseMatrixDescriptor(A::CuSparseMatrixCSC, IndexBase::Char; convert=false)
desc_ref = Ref{cusparseSpMatDescr_t}()
if convert
# many algorithms, e.g. mv! and mm!, do not support CSC sparse format
Expand Down Expand Up @@ -123,17 +123,35 @@ function gather!(X::CuSparseVector, Y::CuVector, index::SparseChar)
X
end

function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixBSR{TA},CuSparseMatrixCSR{TA}},
X::DenseCuVector{T}, beta::Number, Y::DenseCuVector{T}, index::SparseChar, algo::cusparseSpMVAlg_t=CUSPARSE_MV_ALG_DEFAULT) where {TA, T}
m,n = size(A)
function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{TA},CuSparseMatrixCSR{TA},CuSparseMatrixCOO{TA}}, X::DenseCuVector{T},
beta::Number, Y::DenseCuVector{T}, index::SparseChar, algo::cusparseSpMVAlg_t=CUSPARSE_MV_ALG_DEFAULT) where {TA, T}

# Support transa = 'C' for real matrices
transa = T <: Real && transa == 'C' ? 'T' : transa

if isa(A, CuSparseMatrixCSC) && transa == 'C' && TA <: Complex
throw(ArgumentError("Matrix-vector multiplication with the adjoint of a complex CSC matrix" *
" is not supported. Use a CSR or COO matrix instead."))
end

if isa(A, CuSparseMatrixCSC)
# cusparseSpMV doesn't support CSC format with CUSPARSE.version() < v"11.6.1"
# cusparseSpMV supports the CSC format with CUSPARSE.version() ≥ v"11.6.1"
# but it doesn't work for complex numbers when transa == 'C'
descA = CuSparseMatrixDescriptor(A, index, convert=true)
n,m = size(A)
transa = transa == 'N' ? 'T' : 'N'
else
descA = CuSparseMatrixDescriptor(A, index)
m,n = size(A)
end

if transa == 'N'
chkmvdims(X,n,Y,m)
elseif transa == 'T' || transa == 'C'
chkmvdims(X,m,Y,n)
end

descA = CuSparseMatrixDescriptor(A, index)
descX = CuDenseVectorDescriptor(X)
descY = CuDenseVectorDescriptor(Y)

Expand All @@ -158,59 +176,32 @@ function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixBSR{TA},C
cusparseSpMV(handle(), transa, Ref{compute_type}(alpha), descA, descX, Ref{compute_type}(beta),
descY, compute_type, algo, buffer)
end
Y
return Y
end

function mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSC{TA}, X::DenseCuVector{T},
beta::Number, Y::DenseCuVector{T}, index::SparseChar, algo::cusparseSpMVAlg_t=CUSPARSE_MV_ALG_DEFAULT) where {TA, T}
ctransa = 'N'
if transa == 'N'
ctransa = 'T'
elseif transa == 'C' && TA <: Complex
throw(ArgumentError("Matrix-vector multiplication with the adjoint of a CSC matrix" *
" is not supported. Use a CSR matrix instead."))
end
function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}},
B::DenseCuMatrix{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_MM_ALG_DEFAULT) where {T}

n,m = size(A)
# Support transa = 'C' and `transb = 'C' for real matrices
transa = T <: Real && transa == 'C' ? 'T' : transa
transb = T <: Real && transb == 'C' ? 'T' : transb

if ctransa == 'N'
chkmvdims(X,n,Y,m)
elseif ctransa == 'T' || ctransa == 'C'
chkmvdims(X,m,Y,n)
if isa(A, CuSparseMatrixCSC) && transa == 'C' && T <: Complex
throw(ArgumentError("Matrix-matrix multiplication with the adjoint of a complex CSC matrix" *
" is not supported. Use a CSR and COO matrix instead."))
end

descA = CuSparseMatrixDescriptor(A, index)
descX = CuDenseVectorDescriptor(X)
descY = CuDenseVectorDescriptor(Y)

# operations with 16-bit numbers always imply mixed-precision computation
# TODO: we should better model the supported combinations here,
# and error if using an unsupported one (like with gemmEx!)
compute_type = if version() >= v"11.4" && T == Float16
Float32
elseif version() >= v"11.7.2" && T == ComplexF16
ComplexF32
if isa(A, CuSparseMatrixCSC)
# cusparseSpMM doesn't support CSC format with CUSPARSE.version() < v"11.6.1"
# cusparseSpMM supports the CSC format with CUSPARSE.version() ≥ v"11.6.1"
# but it doesn't work for complex numbers when transa == 'C'
descA = CuSparseMatrixDescriptor(A, index, convert=true)
k,m = size(A)
transa = transa == 'N' ? 'T' : 'N'
else
T
end

function bufferSize()
out = Ref{Csize_t}()
cusparseSpMV_bufferSize(handle(), ctransa, Ref{compute_type}(alpha), descA, descX, Ref{compute_type}(beta),
descY, compute_type, algo, out)
return out[]
end
with_workspace(bufferSize) do buffer
cusparseSpMV(handle(), ctransa, Ref{compute_type}(alpha), descA, descX, Ref{compute_type}(beta),
descY, compute_type, algo, buffer)
descA = CuSparseMatrixDescriptor(A, index)
m,k = size(A)
end

return Y
end

function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T},
B::DenseCuMatrix{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_MM_ALG_DEFAULT) where {T}
m,k = size(A)
n = size(C)[2]

if transa == 'N' && transb == 'N'
Expand All @@ -223,7 +214,6 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseM
chkmmdims(B,C,n,m,k,n)
end

descA = CuSparseMatrixDescriptor(A, index)
descB = CuDenseMatrixDescriptor(B)
descC = CuDenseMatrixDescriptor(C)

Expand All @@ -239,50 +229,6 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseM
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
descC, T, algo, buffer)
end

return C
end

function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T},
B::DenseCuMatrix{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_MM_ALG_DEFAULT) where {T}
ctransa = 'N'
if transa == 'N'
ctransa = 'T'
elseif transa == 'C' && T <: Complex
throw(ArgumentError("Matrix-matrix multiplication with the adjoint of a CSC matrix" *
" is not supported. Use a CSR matrix instead."))
end

k,m = size(A)
n = size(C)[2]

if ctransa == 'N' && transb == 'N'
chkmmdims(B,C,k,n,m,n)
elseif ctransa == 'N' && transb != 'N'
chkmmdims(B,C,n,k,m,n)
elseif ctransa != 'N' && transb == 'N'
chkmmdims(B,C,m,n,k,n)
elseif ctransa != 'N' && transb != 'N'
chkmmdims(B,C,n,m,k,n)
end

descA = CuSparseMatrixDescriptor(A, index)
descB = CuDenseMatrixDescriptor(B)
descC = CuDenseMatrixDescriptor(C)

function bufferSize()
out = Ref{Csize_t}()
cusparseSpMM_bufferSize(
handle(), ctransa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
descC, T, algo, out)
return out[]
end
with_workspace(bufferSize) do buffer
cusparseSpMM(
handle(), ctransa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
descC, T, algo, buffer)
end

return C
end

Expand Down Expand Up @@ -363,8 +309,11 @@ function sv!(transa::SparseChar, uplo::SparseChar, diag::SparseChar,
alpha::Number, A::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}}, X::DenseCuVector{T},
Y::DenseCuVector{T}, index::SparseChar, algo::cusparseSpSVAlg_t=CUSPARSE_SPSV_ALG_DEFAULT) where {T}

# Support transa = 'C' for real matrices
transa = T <: Real && transa == 'C' ? 'T' : transa

if isa(A, CuSparseMatrixCSC) && transa == 'C' && T <: Complex
throw(ArgumentError("Backward and forward sweeps with the adjoint of a CSC matrix is not supported. Use a CSR or COO matrix instead."))
throw(ArgumentError("Backward and forward sweeps with the adjoint of a complex CSC matrix is not supported. Use a CSR or COO matrix instead."))
end

mA,nA = size(A)
Expand All @@ -378,15 +327,13 @@ function sv!(transa::SparseChar, uplo::SparseChar, diag::SparseChar,
if isa(A, CuSparseMatrixCSC)
# cusparseSpSV doesn't support CSC format
descA = CuSparseMatrixDescriptor(A, index, convert=true)
transa2 = transa == 'N' ? 'T' : 'N'
uplo2 = uplo == 'U' ? 'L' : 'U'
transa = transa == 'N' ? 'T' : 'N'
uplo = uplo == 'U' ? 'L' : 'U'
else
descA = CuSparseMatrixDescriptor(A, index)
transa2 = transa
uplo2 = uplo
end

cusparse_uplo = Ref{cusparseFillMode_t}(uplo2)
cusparse_uplo = Ref{cusparseFillMode_t}(uplo)
cusparse_diag = Ref{cusparseDiagType_t}(diag)

cusparseSpMatSetAttribute(descA, 'F', cusparse_uplo, Csize_t(sizeof(cusparse_uplo)))
Expand All @@ -398,12 +345,12 @@ function sv!(transa::SparseChar, uplo::SparseChar, diag::SparseChar,
spsv_desc = CuSparseSpSVDescriptor()
function bufferSize()
out = Ref{Csize_t}()
cusparseSpSV_bufferSize(handle(), transa2, Ref{T}(alpha), descA, descX, descY, T, algo, spsv_desc, out)
cusparseSpSV_bufferSize(handle(), transa, Ref{T}(alpha), descA, descX, descY, T, algo, spsv_desc, out)
return out[]
end
with_workspace(bufferSize) do buffer
cusparseSpSV_analysis(handle(), transa2, Ref{T}(alpha), descA, descX, descY, T, algo, spsv_desc, buffer)
cusparseSpSV_solve(handle(), transa2, Ref{T}(alpha), descA, descX, descY, T, algo, spsv_desc)
cusparseSpSV_analysis(handle(), transa, Ref{T}(alpha), descA, descX, descY, T, algo, spsv_desc, buffer)
cusparseSpSV_solve(handle(), transa, Ref{T}(alpha), descA, descX, descY, T, algo, spsv_desc)
end
return Y
end
Expand All @@ -412,8 +359,12 @@ function sm!(transa::SparseChar, transb::SparseChar, uplo::SparseChar, diag::Spa
alpha::Number, A::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}}, B::DenseCuMatrix{T},
C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpSMAlg_t=CUSPARSE_SPSM_ALG_DEFAULT) where {T}

# Support transa = 'C' and `transb = 'C' for real matrices
transa = T <: Real && transa == 'C' ? 'T' : transa
transb = T <: Real && transb == 'C' ? 'T' : transb

if isa(A, CuSparseMatrixCSC) && transa == 'C' && T <: Complex
throw(ArgumentError("Backward and forward sweeps with the adjoint of a CSC matrix is not supported. Use a CSR or COO matrix instead."))
throw(ArgumentError("Backward and forward sweeps with the adjoint of a complex CSC matrix is not supported. Use a CSR or COO matrix instead."))
end

mA,nA = size(A)
Expand All @@ -430,15 +381,13 @@ function sm!(transa::SparseChar, transb::SparseChar, uplo::SparseChar, diag::Spa
if isa(A, CuSparseMatrixCSC)
# cusparseSpSM doesn't support CSC format
descA = CuSparseMatrixDescriptor(A, index, convert=true)
transa2 = transa == 'N' ? 'T' : 'N'
uplo2 = uplo == 'U' ? 'L' : 'U'
transa = transa == 'N' ? 'T' : 'N'
uplo = uplo == 'U' ? 'L' : 'U'
else
descA = CuSparseMatrixDescriptor(A, index)
transa2 = transa
uplo2 = uplo
end

cusparse_uplo = Ref{cusparseFillMode_t}(uplo2)
cusparse_uplo = Ref{cusparseFillMode_t}(uplo)
cusparse_diag = Ref{cusparseDiagType_t}(diag)

cusparseSpMatSetAttribute(descA, 'F', cusparse_uplo, Csize_t(sizeof(cusparse_uplo)))
Expand All @@ -450,12 +399,12 @@ function sm!(transa::SparseChar, transb::SparseChar, uplo::SparseChar, diag::Spa
spsm_desc = CuSparseSpSMDescriptor()
function bufferSize()
out = Ref{Csize_t}()
cusparseSpSM_bufferSize(handle(), transa2, transb, Ref{T}(alpha), descA, descB, descC, T, algo, spsm_desc, out)
cusparseSpSM_bufferSize(handle(), transa, transb, Ref{T}(alpha), descA, descB, descC, T, algo, spsm_desc, out)
return out[]
end
with_workspace(bufferSize) do buffer
cusparseSpSM_analysis(handle(), transa2, transb, Ref{T}(alpha), descA, descB, descC, T, algo, spsm_desc, buffer)
cusparseSpSM_solve(handle(), transa2, transb, Ref{T}(alpha), descA, descB, descC, T, algo, spsm_desc)
cusparseSpSM_analysis(handle(), transa, transb, Ref{T}(alpha), descA, descB, descC, T, algo, spsm_desc, buffer)
cusparseSpSM_solve(handle(), transa, transb, Ref{T}(alpha), descA, descB, descC, T, algo, spsm_desc)
end
return C
end
10 changes: 5 additions & 5 deletions lib/cusparse/level2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
export sv2!, sv2, mv!

"""
mv!(transa::SparseChar, alpha::BlasFloat, A::CuSparseMatrix, X::CuVector, beta::BlasFloat, Y::CuVector, index::SparseChar)
mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)
Performs `Y = alpha * op(A) * X + beta * Y`, where `op` can be nothing (`transa = N`),
tranpose (`transa = T`) or conjugate transpose (`transa = C`). `X` is a sparse vector, and
`Y` is dense.
tranpose (`transa = T`) or conjugate transpose (`transa = C`).
`X` and `Y` are dense vectors.
"""
mv!(transa::SparseChar, alpha::BlasFloat, A::CuSparseMatrix, X::CuVector, beta::BlasFloat, Y::CuVector, index::SparseChar)
mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)

for (fname,elty) in ((:cusparseSbsrmv, :Float32),
(:cusparseDbsrmv, :Float64),
Expand Down Expand Up @@ -175,7 +175,7 @@ for (bname,aname,sname,elty) in ((:cusparseScsrsv2_bufferSize, :cusparseScsrsv2_
if transa == 'N'
ctransa = 'T'
elseif transa == 'C' && eltype(A) <: Complex
throw(ArgumentError("Backward and forward sweeps with the adjoint of a CSC matrix is not supported. Use a CSR matrix instead."))
throw(ArgumentError("Backward and forward sweeps with the adjoint of a complex CSC matrix is not supported. Use a CSR matrix instead."))
end
if uplo == 'U'
cuplo = 'L'
Expand Down
8 changes: 4 additions & 4 deletions lib/cusparse/level3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
export mm!, mm2!, sm2!, sm2

"""
mm!(transa::SparseChar, transb::SparseChar, alpha::BlasFloat, A::CuSparseMatrix, B::CuMatrix, beta::BlasFloat, C::CuMatrix, index::SparseChar)
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix, B::CuMatrix, beta::Number, C::CuMatrix, index::SparseChar)
Performs `C = alpha * op(A) * op(B) + beta * C`, where `op` can be nothing (`transa = N`),
tranpose (`transa = T`) or conjugate transpose (`transa = C`).
`A` is a sparse matrix defined in BSR storage format. `B` and `C` are dense matrices.
`B` and `C` are dense matrices.
"""
mm!(transa::SparseChar, transb::SparseChar, alpha::BlasFloat, A::CuSparseMatrix, B::CuMatrix, beta::BlasFloat, C::CuMatrix, index::SparseChar)
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix, B::CuMatrix, beta::Number, C::CuMatrix, index::SparseChar)

# bsrmm
for (fname,elty) in ((:cusparseSbsrmm, :Float32),
Expand Down Expand Up @@ -276,7 +276,7 @@ for (bname,aname,sname,elty) in ((:cusparseScsrsm2_bufferSizeExt, :cusparseScsrs
if transa == 'N'
ctransa = 'T'
elseif transa == 'C' && eltype(A) <: Complex
throw(ArgumentError("Backward and forward sweeps with the adjoint of a CSC matrix is not supported. Use a CSR matrix instead."))
throw(ArgumentError("Backward and forward sweeps with the adjoint of a complex CSC matrix is not supported. Use a CSR matrix instead."))
end
if uplo == 'U'
cuplo = 'L'
Expand Down
4 changes: 2 additions & 2 deletions lib/cusparse/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ function chkmvdims(X, n, Y, m)
end

"""
check that the dimensions of matrices `X` and `Y` make sense for a multiplication
check that the dimensions of matrices `B` and `C` make sense for a multiplication
"""
function chkmmdims( B, C, k, l, m, n )
function chkmmdims(B, C, k, l, m, n)
if size(B) != (k,l)
throw(DimensionMismatch("B has dimensions $(size(B)) but needs ($k,$l)"))
elseif size(C) != (m,n)
Expand Down
Loading

0 comments on commit c5072c0

Please sign in to comment.