Skip to content

Commit

Permalink
Fix SpMV for CUDA 11.5
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored and maleadt committed Nov 8, 2021
1 parent 8e5b4a2 commit d106de5
Showing 1 changed file with 8 additions and 33 deletions.
41 changes: 8 additions & 33 deletions lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,52 +28,27 @@ end
tag_wrappers = ((identity, identity),
(T -> :(HermOrSym{T, <:$T}), A -> :(parent($A))))
op_wrappers = (
(identity, 'N', identity),
(T -> :(Transpose{<:T, <:$T}), 'T', A -> :(parent($A))),
(T -> :(Adjoint{<:T, <:$T}), 'C', A -> :(parent($A)))
(identity, T -> 'N', identity),
(T -> :(Transpose{<:T, <:$T}), T -> 'T', A -> :(parent($A))),
(T -> :(Adjoint{<:T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A)))
)
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(taga(:(CuSparseMatrix{T})))

@eval begin
function LinearAlgebra.mul!(C::CuVector{T}, A::$TypeA, B::DenseCuVector{T},
alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
mv_wrapper($transa, alpha, $(untaga(unwrapa(:A))), B, beta, C)
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
end
end

for (tagb, untagb) in tag_wrappers, (wrapb, transb, unwrapb) in op_wrappers
TypeB = wrapb(tagb(:(DenseCuMatrix{T})))

isadjoint(expr) = expr.head == :curly && expr.args[1] == :Adjoint
if isadjoint(TypeA) || isadjoint(TypeB)
# CUSPARSE defines adjoints only for complex inputs. For real inputs we run
# adjoints as tranposes so that we can still support the whole API surface of
# LinearAlgebra.
@eval begin
function LinearAlgebra.mul!(C::CuMatrix{T}, A::$TypeA, B::$TypeB,
alpha::Number, beta::Number) where {T <: Union{ComplexF16, BlasComplex}}
mm_wrapper($transa, $transb, alpha, $(untaga(unwrapa(:A))),
$(untagb(unwrapb(:B))), beta, C)
end
end

transa_real = transa == 'C' ? 'T' : transa
transb_real = transb == 'C' ? 'T' : transb
@eval begin
function LinearAlgebra.mul!(C::CuMatrix{T}, A::$TypeA, B::$TypeB,
alpha::Number, beta::Number) where {T <: Union{Float16, BlasReal}}
mm_wrapper($transa_real, $transb_real, alpha, $(untaga(unwrapa(:A))),
$(untagb(unwrapb(:B))), beta, C)
end
end
else
@eval begin
function LinearAlgebra.mul!(C::CuMatrix{T}, A::$TypeA, B::$TypeB,
alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
mm_wrapper($transa, $transb, alpha, $(untaga(unwrapa(:A))),
$(untagb(unwrapb(:B))), beta, C)
end
@eval begin
function LinearAlgebra.mul!(C::CuMatrix{T}, A::$TypeA, B::$TypeB,
alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
mm_wrapper($transa(T), $transb(T), alpha, $(untaga(unwrapa(:A))), $(untagb(unwrapb(:B))), beta, C)
end
end
end
Expand Down

0 comments on commit d106de5

Please sign in to comment.