Skip to content

Commit

Permalink
Use correct transposition symbols in Ax_ldiv_B! for UMFPACK. Fixes Ju…
Browse files Browse the repository at this point in the history
…liaLang#21877 (JuliaLang#21889)

Make sure the methods are getting called from A'\b and A.'\b.

Don't promote Factorizations in general

Promote Bunch-Kaufman since it doesn't have a generic fallback solver
  • Loading branch information
andreasnoack authored May 16, 2017
1 parent 149e589 commit 57ebd82
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 22 deletions.
5 changes: 5 additions & 0 deletions base/linalg/bunchkaufman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ function A_ldiv_B!(B::BunchKaufman{T}, R::StridedVecOrMat{T}) where T<:BlasCompl
end
end
end
# There is no fallback solver for Bunch-Kaufman so we'll have to promote to same element type
function A_ldiv_B!(B::BunchKaufman{T}, R::StridedVecOrMat{S}) where {T,S}
TS = promote_type(T,S)
return A_ldiv_B!(convert(BunchKaufman{TS}, B), convert(AbstractArray{TS}, R))
end

function logabsdet(F::BunchKaufman)
M = F.LD
Expand Down
2 changes: 1 addition & 1 deletion base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ for (f1, f2) in ((:\, :A_ldiv_B!),
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
BB = similar(B, TFB, size(B))
copy!(BB, B)
$f2(convert(Factorization{TFB}, F), BB)
$f2(F, BB)
end
end
end
Expand Down
37 changes: 20 additions & 17 deletions base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -851,29 +851,32 @@ end
scale!(A::SparseMatrixCSC, b::Number) = (scale!(A.nzval, b); A)
scale!(b::Number, A::SparseMatrixCSC) = (scale!(b, A.nzval); A)

function (\)(A::SparseMatrixCSC, B::AbstractVecOrMat)
m, n = size(A)
if m == n
if istril(A)
if istriu(A)
return Diagonal(A) \ B
for f in (:\, :Ac_ldiv_B, :At_ldiv_B)
@eval begin
function ($f)(A::SparseMatrixCSC, B::AbstractVecOrMat)
m, n = size(A)
if m == n
if istril(A)
if istriu(A)
return ($f)(Diagonal(A), B)
else
return ($f)(LowerTriangular(A), B)
end
elseif istriu(A)
return ($f)(UpperTriangular(A), B)
end
if ishermitian(A)
return ($f)(Hermitian(A), B)
end
return ($f)(lufact(A), B)
else
return LowerTriangular(A) \ B
return ($f)(qrfact(A), B)
end
elseif istriu(A)
return UpperTriangular(A) \ B
end
if ishermitian(A)
return Hermitian(A) \ B
end
return lufact(A) \ B
else
return qrfact(A) \ B
($f)(::SparseMatrixCSC, ::RowVector) = throw(DimensionMismatch("Cannot left-divide matrix by transposed vector"))
end
end

(\)(::SparseMatrixCSC, ::RowVector) = throw(DimensionMismatch("Cannot left-divide matrix by transposed vector"))

function factorize(A::SparseMatrixCSC)
m, n = size(A)
if m == n
Expand Down
8 changes: 4 additions & 4 deletions base/sparse/umfpack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,15 +392,15 @@ Ac_ldiv_B!(lu::UmfpackLU{Float64}, B::StridedVecOrMat{<:Complex}) = Ac_ldiv_B!(B
A_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_A)
At_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_At)
Ac_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_Aat)
Ac_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_At)
A_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_A)
At_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_At)
Ac_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_Aat)
Ac_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_At)

function _Aq_ldiv_B!(X::StridedVecOrMat, lu::UmfpackLU, B::StridedVecOrMat, transposeoptype)
if size(X, 2) != size(B, 2)
Expand Down
7 changes: 7 additions & 0 deletions test/sparse/umfpack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,16 @@ end
Ac0 = complex.(A0,A0)
for Ti in Base.uniontypes(Base.SparseArrays.UMFPACK.UMFITypes)
Ac = convert(SparseMatrixCSC{Complex128,Ti}, Ac0)
x = complex.(ones(size(Ac, 1)), ones(size(Ac,1)))
lua = lufact(Ac)
L,U,p,q,Rs = lua[:(:)]
@test (Diagonal(Rs) * Ac)[p,q] L * U
b = Ac*x
@test Ac\b x
b = Ac'*x
@test Ac'\b x
b = Ac.'*x
@test Ac.'\b x
end

for elty in (Float64, Complex128)
Expand Down

0 comments on commit 57ebd82

Please sign in to comment.