Skip to content

Commit

Permalink
Fix multiplication of triangular and sparse matrix (JuliaLang#33579)
Browse files Browse the repository at this point in the history
  • Loading branch information
goggle authored and StefanKarpinski committed Oct 17, 2019
1 parent 6a82665 commit 954dc9b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 16 deletions.
33 changes: 17 additions & 16 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import LinearAlgebra: checksquare
using Random: rand!

# In matrix-vector multiplication, the correct orientation of the vector is assumed.
const AdjOrTransStridedMatrix{T} = Union{StridedMatrix{T},Adjoint{<:Any,<:StridedMatrix{T}},Transpose{<:Any,<:StridedMatrix{T}}}
const StridedOrTriangularMatrix{T} = Union{StridedMatrix{T}, LowerTriangular{T}, UnitLowerTriangular{T}, UpperTriangular{T}, UnitUpperTriangular{T}}
const AdjOrTransStridedOrTriangularMatrix{T} = Union{StridedOrTriangularMatrix{T},Adjoint{<:Any,<:StridedOrTriangularMatrix{T}},Transpose{<:Any,<:StridedOrTriangularMatrix{T}}}

function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedMatrix}, α::Number, β::Number)
function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
Expand All @@ -27,10 +28,10 @@ function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVe
end
*(A::AbstractSparseMatrixCSC{TA}, x::StridedVector{Tx}) where {TA,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(x, T, size(A, 1)), A, x, true, false))
*(A::AbstractSparseMatrixCSC{TA}, B::StridedMatrix{Tx}) where {TA,Tx} =
*(A::AbstractSparseMatrixCSC{TA}, B::AdjOrTransStridedOrTriangularMatrix{Tx}) where {TA,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, true, false))

function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}, α::Number, β::Number)
function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
A = adjA.parent
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
Expand All @@ -53,10 +54,10 @@ function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}
end
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
(T = promote_op(matprod, eltype(adjA), Tx); mul!(similar(x, T, size(adjA, 1)), adjA, x, true, false))
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedMatrix) =
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) =
(T = promote_op(matprod, eltype(adjA), eltype(B)); mul!(similar(B, T, (size(adjA, 1), size(B, 2))), adjA, B, true, false))

function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}, α::Number, β::Number)
function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
A = transA.parent
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
Expand All @@ -79,19 +80,19 @@ function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrix
end
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
(T = promote_op(matprod, eltype(transA), Tx); mul!(similar(x, T, size(transA, 1)), transA, x, true, false))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedMatrix) =
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) =
(T = promote_op(matprod, eltype(transA), eltype(B)); mul!(similar(B, T, (size(transA, 1), size(B, 2))), transA, B, true, false))

# For compatibility with dense multiplication API. Should be deleted when dense multiplication
# API is updated to follow BLAS API.
mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedMatrix}) =
mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}) =
mul!(C, A, B, true, false)
mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}) =
mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}) =
mul!(C, adjA, B, true, false)
mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}) =
mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}) =
mul!(C, transA, B, true, false)

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, A::AbstractSparseMatrixCSC, α::Number, β::Number)
function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::AbstractSparseMatrixCSC, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) || throw(DimensionMismatch())
mX == size(C, 1) || throw(DimensionMismatch())
Expand All @@ -106,10 +107,10 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, A::AbstractSparseM
end
C
end
*(X::AdjOrTransStridedMatrix, A::AbstractSparseMatrixCSC{TvA}) where {TvA} =
*(X::AdjOrTransStridedOrTriangularMatrix, A::AbstractSparseMatrixCSC{TvA}) where {TvA} =
(T = promote_op(matprod, eltype(X), TvA); mul!(similar(X, T, (size(X, 1), size(A, 2))), X, A, true, false))

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
A = adjA.parent
mX, nX = size(X)
nX == size(A, 2) || throw(DimensionMismatch())
Expand All @@ -125,10 +126,10 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, adjA::Adjoint{<:An
end
C
end
*(X::AdjOrTransStridedMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) =
*(X::AdjOrTransStridedOrTriangularMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) =
(T = promote_op(matprod, eltype(X), eltype(adjA)); mul!(similar(X, T, (size(X, 1), size(adjA, 2))), X, adjA, true, false))

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
A = transA.parent
mX, nX = size(X)
nX == size(A, 2) || throw(DimensionMismatch())
Expand All @@ -144,7 +145,7 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, transA::Transpose{
end
C
end
*(X::AdjOrTransStridedMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}) =
*(X::AdjOrTransStridedOrTriangularMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}) =
(T = promote_op(matprod, eltype(X), eltype(transA)); mul!(similar(X, T, (size(X, 1), size(transA, 2))), X, transA, true, false))

function (*)(D::Diagonal, A::AbstractSparseMatrixCSC)
Expand Down
23 changes: 23 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,29 @@ end
end
end

@testset "multiplication of sparse matrix and triangular matrix" begin
_sparse_test_matrix(n, T) = T == Int ? sparse(rand(0:4, n, n)) : sprandn(T, n, n, 0.6)
_triangular_test_matrix(n, TA, T) = T == Int ? TA(rand(0:9, n, n)) : TA(randn(T, n, n))

n = 5
for T1 in (Int, Float64, ComplexF32)
S = _sparse_test_matrix(n, T1)
MS = Matrix(S)
for T2 in (Int, Float64, ComplexF32)
for TM in (LowerTriangular, UnitLowerTriangular, UpperTriangular, UnitLowerTriangular)
T = _triangular_test_matrix(n, TM, T2)
MT = Matrix(T)
@test isa(T * S, DenseMatrix)
@test isa(S * T, DenseMatrix)
for transT in (identity, adjoint, transpose), transS in (identity, adjoint, transpose)
@test transT(T) * transS(S) transT(MT) * transS(MS)
@test transS(S) * transT(T) transS(MS) * transT(MT)
end
end
end
end
end

@testset "Issue #30502" begin
@test nnz(sprand(UInt8(16), UInt8(16), 1.0)) == 256
@test nnz(sprand(UInt8(16), UInt8(16), 1.0, ones)) == 256
Expand Down

0 comments on commit 954dc9b

Please sign in to comment.