Skip to content

Commit

Permalink
Make matrix multiplication work for more types (JuliaLang#18218)
Browse files Browse the repository at this point in the history
* Make matrix multiplication work for more types

Currently it is assumed that the type of a sum of x::T and y::T
is T but this may not be the case

* Remove arithtype in matmul and deprecate it
  • Loading branch information
blegat authored and andreasnoack committed Oct 24, 2016
1 parent f80ea1c commit 1efe487
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 18 deletions.
18 changes: 18 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1058,4 +1058,22 @@ function reduced_dims0(dims::Dims, region)
map(last, reduced_dims0(map(n->OneTo(n), dims), region))
end

# #18218
eval(Base.LinAlg, quote
function arithtype(T)
depwarn(string("arithtype is now deprecated. If you were using it inside a ",
"promote_op call, use promote_op(LinAlg.matprod, Ts...) instead. Otherwise, ",
"if you need its functionality, consider defining it locally."),
:arithtype)
T
end
function arithtype(::Type{Bool})
depwarn(string("arithtype is now deprecated. If you were using it inside a ",
"promote_op call, use promote_op(LinAlg.matprod, Ts...) instead. Otherwise, ",
"if you need its functionality, consider defining it locally."),
:arithtype)
Int
end
end)

# End deprecations scheduled for 0.6
35 changes: 17 additions & 18 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

# matmul.jl: Everything to do with dense matrix multiplication

arithtype(T) = T
arithtype(::Type{Bool}) = Int
matprod(x, y) = x*y + x*y

# multiply by diagonal matrix as vector
function scale!(C::AbstractMatrix, A::AbstractMatrix, b::AbstractVector)
Expand Down Expand Up @@ -76,11 +75,11 @@ At_mul_B{T<:BlasComplex}(x::StridedVector{T}, y::StridedVector{T}) = [BLAS.dotu(

# Matrix-vector multiplication
function (*){T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_B!(similar(x, TS, size(A,1)), A, convert(AbstractVector{TS}, x))
end
function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_B!(similar(x,TS,size(A,1)),A,x)
end
(*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B
Expand All @@ -99,22 +98,22 @@ end
A_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'N', A, x)

function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_B!(similar(x,TS,size(A,2)), A, convert(AbstractVector{TS}, x))
end
function At_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_B!(similar(x,TS,size(A,2)), A, x)
end
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
At_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'T', A, x)

function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
Ac_mul_B!(similar(x,TS,size(A,2)),A,convert(AbstractVector{TS},x))
end
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
Ac_mul_B!(similar(x,TS,size(A,2)), A, x)
end

Expand All @@ -132,7 +131,7 @@ Ac_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_m
Matrix multiplication.
"""
function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_B!(similar(B, TS, (size(A,1), size(B,2))), A, B)
end
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
Expand Down Expand Up @@ -166,14 +165,14 @@ julia> Y
A_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'N', A, B)

function At_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
At_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'T', 'N', A, B)

function A_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_Bt!(similar(B, TS, (size(A,1), size(B,1))), A, B)
end
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
Expand All @@ -190,7 +189,7 @@ end
A_mul_Bt!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'T', A, B)

function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractVecOrMat{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_Bt!(similar(B, TS, (size(A,2), size(B,1))), A, B)
end
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
Expand All @@ -199,7 +198,7 @@ At_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generi
Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B)
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = At_mul_B!(C, A, B)
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
Ac_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
Expand All @@ -208,14 +207,14 @@ Ac_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic
A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B)
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{S}) = A_mul_Bt!(C, A, B)
function A_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B)
end
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
A_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'C', A, B)

Ac_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) =
Ac_mul_Bc!(similar(B, promote_op(*, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!(similar(B, promote_op(matprod, T, S), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
Ac_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'C', A, B)
Ac_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'T', A, B)
Expand Down Expand Up @@ -448,7 +447,7 @@ end
function generic_matmatmul{T,S}(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, promote_op(*, arithtype(T), arithtype(S)), mA, nB)
C = similar(B, promote_op(matprod, T, S), mA, nB)
generic_matmatmul!(C, tA, tB, A, B)
end

Expand Down Expand Up @@ -642,7 +641,7 @@ end

# multiply 2x2 matrices
function matmul2x2{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
matmul2x2!(similar(B, promote_op(*, T, S), 2, 2), tA, tB, A, B)
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
end

function matmul2x2!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
Expand Down Expand Up @@ -671,7 +670,7 @@ end

# Multiply 3x3 matrices
function matmul3x3{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
matmul3x3!(similar(B, promote_op(*, T, S), 3, 3), tA, tB, A, B)
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
end

function matmul3x3!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
Expand Down
27 changes: 27 additions & 0 deletions test/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,30 @@ let
@test_throws DimensionMismatch A_mul_B!(full43, full43, tri44)
end
end

# #18218
module TestPR18218
using Base.Test
import Base.*, Base.+, Base.zero
immutable TypeA
x::Int
end
Base.convert(::Type{TypeA}, x::Int) = TypeA(x)
immutable TypeB
x::Int
end
immutable TypeC
x::Int
end
Base.convert(::Type{TypeC}, x::Int) = TypeC(x)
zero(c::TypeC) = TypeC(0)
zero(::Type{TypeC}) = TypeC(0)
(*)(x::Int, a::TypeA) = TypeB(x*a.x)
(*)(a::TypeA, x::Int) = TypeB(a.x*x)
(+)(a::Union{TypeB,TypeC}, b::Union{TypeB,TypeC}) = TypeC(a.x+b.x)
A = TypeA[1 2; 3 4]
b = [1, 2]
d = A * b
@test typeof(d) == Vector{TypeC}
@test d == TypeC[5, 11]
end

0 comments on commit 1efe487

Please sign in to comment.