Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Nov 4, 2013
1 parent d720865 commit d46936b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
23 changes: 17 additions & 6 deletions base/linalg/umfpack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,23 +247,34 @@ end

(\){T<:UMFVTypes}(fact::UmfpackLU{T}, b::Vector{T}) = solve(fact, b)
(\){Ts<:UMFVTypes,Tb<:Number}(fact::UmfpackLU{Ts}, b::Vector{Tb}) = fact\convert(Vector{Ts},b)
function (\){Tb<:Complex}(fact::UmfpackLU{Float64}, b::Vector{Tb})
r = fact\[convert(Float64,real(be)) for be in b]
i = fact\[convert(Float64,imag(be)) for be in b]
Complex128[r[k]+im*i[k] for k = 1:length(r)]
end
At_ldiv_B{T<:UMFVTypes}(fact::UmfpackLU{T}, b::Vector{T}) = solve(fact, b, UMFPACK_Aat)
At_ldiv_B{Ts<:UMFVTypes,Tb<:Number}(fact::UmfpackLU{Ts}, b::Vector{Tb}) = fact.'\convert(Vector{Ts},b)
At_ldiv_B{Tb<:Complex}(fact::UmfpackLU{Float64}, b::Vector{Tb}) = fact.'\b
Ac_ldiv_B{T<:UMFVTypes}(fact::UmfpackLU{T}, b::Vector{T}) = solve(fact, b, UMFPACK_At)
Ac_ldiv_B{Ts<:UMFVTypes,Tb<:Number}(fact::UmfpackLU{Ts}, b::Vector{Tb}) = fact'\convert(Vector{Ts},b)
Ac_ldiv_B{Tb<:Complex}(fact::UmfpackLU{Float64}, b::Vector{Tb}) = fact'\b

### Solve directly with matrix

(\)(S::SparseMatrixCSC, b::Vector) = lufact(S) \ b
At_ldiv_B{T<:UMFVTypes}(S::SparseMatrixCSC{T}, b::Vector{T}) = solve(lufact(S), b, UMFPACK_Aat)
function At_ldiv_B{Ts<:UMFVTypes,Tb<:Number}(S::SparseMatrixCSC{Ts}, b::Vector{Tb})
## should be more careful here in case Ts<:Real and Tb<:Complex
At_ldiv_B(S, convert(Vector{Ts}, b))
At_ldiv_B{Ts<:UMFVTypes,Tb<:Number}(S::SparseMatrixCSC{Ts}, b::Vector{Tb}) = At_ldiv_B(S, convert(Vector{Ts}, b))
function At_ldiv_B{Tb<:Complex}(S::SparseMatrixCSC{Float64}, b::Vector{Tb})
r = At_ldiv_B(S, [convert(Float64,real(be)) for be in b])
i = At_ldiv_B(S, [convert(Float64,imag(be)) for be in b])
Complex128[r[k]+im*i[k] for k = 1:length(r)]
end
Ac_ldiv_B{T<:UMFVTypes}(S::SparseMatrixCSC{T}, b::Vector{T}) = solve(lufact(S), b, UMFPACK_At)
function Ac_ldiv_B{Ts<:UMFVTypes,Tb<:Number}(S::SparseMatrixCSC{Ts}, b::Vector{Tb})
## should be more careful here in case Ts<:Real and Tb<:Complex
Ac_ldiv_B(S, convert(Vector{Ts}, b))
Ac_ldiv_B{Ts<:UMFVTypes,Tb<:Number}(S::SparseMatrixCSC{Ts}, b::Vector{Tb}) = Ac_ldiv_B(S, convert(Vector{Ts}, b))
function Ac_ldiv_B{Tb<:Complex}(S::SparseMatrixCSC{Float64}, b::Vector{Tb})
r = Ac_ldiv_B(S, [convert(Float64,real(be)) for be in b])
i = Ac_ldiv_B(S, [convert(Float64,imag(be)) for be in b])
Complex128[r[k]+im*i[k] for k = 1:length(r)]
end

solve(lu::UmfpackLU, b::Vector) = solve(lu, b, UMFPACK_A)
Expand Down
24 changes: 22 additions & 2 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,34 @@ a116[p, p] = reshape(1:9, 3, 3)
s116[p, p] = reshape(1:9, 3, 3)
@test a116 == s116

# check matrix-vector multiplication
# matrix-vector multiplication (non-square)
for i = 1:5
a = sprand(10, 5, 0.5)
b = rand(5)
@test maximum(abs(a*b - dense(a)*b)) < 100*eps()
end

# check matrix multiplication
# complex matrix-vector multiplication and left-division
for i = 1:5
a = speye(5) + 0.1*sprandn(5, 5, 0.2)
b = randn(5) + im*randn(5)
@test (maximum(abs(a*b - dense(a)*b)) < 100*eps())
@test (maximum(abs(a\b - dense(a)\b)) < 1000*eps())
@test (maximum(abs(a'\b - dense(a')\b)) < 1000*eps())
a = speye(5) + 0.1*sprandn(5, 5, 0.2) + 0.1*im*sprandn(5, 5, 0.2)
b = randn(5)
@test (maximum(abs(a*b - dense(a)*b)) < 100*eps())
@test (maximum(abs(a\b - dense(a)\b)) < 1000*eps())
@test (maximum(abs(a'\b - dense(a')\b)) < 1000*eps())
@test (maximum(abs(a.'\b - dense(a.')\b)) < 1000*eps())
b = randn(5) + im*randn(5)
@test (maximum(abs(a*b - dense(a)*b)) < 100*eps())
@test (maximum(abs(a\b - dense(a)\b)) < 1000*eps())
@test (maximum(abs(a'\b - dense(a')\b)) < 1000*eps())
@test (maximum(abs(a.'\b - dense(a.')\b)) < 1000*eps())
end

# matrix multiplication
for i = 1:5
a = sprand(10, 5, 0.5)
b = sprand(5, 10, 0.1)
Expand Down

0 comments on commit d46936b

Please sign in to comment.