Skip to content

Commit

Permalink
Fix dispatch on sparse multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Mar 7, 2016
1 parent 03c2ef6 commit 1982118
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
4 changes: 3 additions & 1 deletion base/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ using Base.Sort: Forward
using Base.LinAlg: AbstractTriangular, PosDefException

import Base: +, -, *, \, &, |, $, .+, .-, .*, ./, .\, .^, .<, .!=, ==
import Base: A_mul_B!, Ac_mul_B, Ac_mul_B!, At_mul_B, At_mul_B!, At_ldiv_B, Ac_ldiv_B, A_ldiv_B!
import Base: A_mul_B!, Ac_mul_B, Ac_mul_B!, At_mul_B, At_mul_B!
import Base: A_mul_Bc, A_mul_Bt, Ac_mul_Bc, At_mul_Bt
import Base: At_ldiv_B, Ac_ldiv_B, A_ldiv_B!
import Base.LinAlg: At_ldiv_B!, Ac_ldiv_B!

import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
Expand Down
27 changes: 26 additions & 1 deletion base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,24 @@ increment{T<:Integer}(A::AbstractArray{T}) = increment!(copy(A))
## sparse matrix multiplication

function (*){TvA,TiA,TvB,TiB}(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB})
(*)(sppromote(A, B)...)
end
for f in (:A_mul_Bt, :A_mul_Bc,
:At_mul_B, :Ac_mul_B,
:At_mul_Bt, :Ac_mul_Bc)
@eval begin
function ($f){TvA,TiA,TvB,TiB}(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB})
($f)(sppromote(A, B)...)
end
end
end

function sppromote{TvA,TiA,TvB,TiB}(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB})
Tv = promote_type(TvA, TvB)
Ti = promote_type(TiA, TiB)
A = convert(SparseMatrixCSC{Tv,Ti}, A)
B = convert(SparseMatrixCSC{Tv,Ti}, B)
A * B
A, B
end

# In matrix-vector multiplication, the correct orientation of the vector is assumed.
Expand Down Expand Up @@ -109,6 +122,18 @@ end
# http://dl.acm.org/citation.cfm?id=355796

(*){Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) = spmatmul(A,B)
for (f, opA, opB) in ((:A_mul_Bt, :identity, :transpose),
(:A_mul_Bc, :identity, :ctranspose),
(:At_mul_B, :transpose, :identity),
(:Ac_mul_B, :ctranspose, :identity),
(:At_mul_Bt, :transpose, :transpose),
(:Ac_mul_Bc, :ctranspose, :ctranspose))
@eval begin
function ($f){Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti})
spmatmul(($opA)(A), ($opB)(B))
end
end
end

function spmatmul{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti};
sortindices::Symbol = :sortcols)
Expand Down

0 comments on commit 1982118

Please sign in to comment.