Skip to content

Commit

Permalink
Review annotations and test for allocations in generic matmatmul (Jul…
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Dec 6, 2023
1 parent b5abac4 commit 6a1df3d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 19 deletions.
16 changes: 7 additions & 9 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,7 @@ For custom matrix and vector types, it is recommended to implement
5-argument `mul!` rather than implementing 3-argument `mul!` directly
if possible.
"""
@inline function mul!(C, A, B)
return mul!(C, A, B, true, false)
end
mul!(C, A, B) = mul!(C, A, B, true, false)

"""
mul!(C, A, B, α, β) -> C
Expand Down Expand Up @@ -337,6 +335,7 @@ julia> lmul!(F.Q, B)
lmul!(A, B)

# THE one big BLAS dispatch
# aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
if all(in(('N', 'T', 'C')), (tA, tB))
Expand Down Expand Up @@ -567,7 +566,7 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
end
end

Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
Expand Down Expand Up @@ -607,7 +606,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasReal}
mA, nA = lapack_size(tA, A)
Expand Down Expand Up @@ -762,8 +761,7 @@ function generic_matmatmul(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
generic_matmatmul!(C, tA, tB, A, B)
end

const tilebufsize = 10800 # Approximately 32k/3

# Generic fallback: re-wrap `A` and `B` according to `tA`/`tB` via `wrap` and forward to
# `_generic_matmatmul!`. Aggressive constant propagation is requested so that a mixed-eltype
# `mul!(C, A, B)` ends up invoking `_generic_matmatmul!` directly.
Base.@constprop :aggressive function generic_matmatmul!(C::AbstractVecOrMat, tA, tB,
        A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul)
    wrappedA = wrap(A, tA)
    wrappedB = wrap(B, tB)
    return _generic_matmatmul!(C, wrappedA, wrappedB, _add)
end

Expand Down Expand Up @@ -831,7 +829,7 @@ function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
end

Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
if !(size(A) == size(B) == size(C) == (2,2))
Expand Down Expand Up @@ -898,7 +896,7 @@ function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
end

Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
if !(size(A) == size(B) == size(C) == (3,3))
Expand Down
65 changes: 55 additions & 10 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ mul_wrappers = [
@test Base.infer_return_type((Vector{Float64},)) do v
LinearAlgebra.wrap(v, 'N')
end == Vector{Float64}
h(A) = LinearAlgebra.wrap(LinearAlgebra._unwrap(A), LinearAlgebra.wrapper_char(A))
@test @inferred(h(transpose(A))) === transpose(A)
@test @inferred(h(adjoint(A))) === transpose(A)
end

@testset "matrices with zero dimensions" begin
Expand Down Expand Up @@ -131,7 +134,7 @@ end
@test mul!(C, transpose(A), B) == A' * B
@test mul!(C, A, transpose(B)) == A * B'
@test mul!(C, transpose(A), transpose(B)) == A' * B'
@test LinearAlgebra.mul!(C, adjoint(A), transpose(B)) == A' * transpose(B)
@test mul!(C, adjoint(A), transpose(B)) == A' * transpose(B)

# Inplace multiply-add
α = rand(-10:10)
Expand All @@ -147,8 +150,8 @@ end
@test mul!(C0(), adjoint(A), transpose(B), α, β) == α * A' * transpose(B) .+ βC

#test DimensionMismatch for generic_matmatmul
@test_throws DimensionMismatch LinearAlgebra.mul!(C, adjoint(A), transpose(fill(1, 4, 4)))
@test_throws DimensionMismatch LinearAlgebra.mul!(C, adjoint(fill(1, 4, 4)), transpose(B))
@test_throws DimensionMismatch mul!(C, adjoint(A), transpose(fill(1, 4, 4)))
@test_throws DimensionMismatch mul!(C, adjoint(fill(1, 4, 4)), transpose(B))
end
vv = [1, 2]
CC = Matrix{Int}(undef, 2, 2)
Expand Down Expand Up @@ -240,9 +243,9 @@ end
BB = rand(Float64, 6, 6)
CC = zeros(Float64, 6, 6)
for A in (copy(AA), view(AA, 1:6, 1:6)), B in (copy(BB), view(BB, 1:6, 1:6)), C in (copy(CC), view(CC, 1:6, 1:6))
@test LinearAlgebra.mul!(C, transpose(A), transpose(B)) == transpose(A) * transpose(B)
@test LinearAlgebra.mul!(C, A, adjoint(B)) == A * transpose(B)
@test LinearAlgebra.mul!(C, adjoint(A), B) == transpose(A) * B
@test mul!(C, transpose(A), transpose(B)) == transpose(A) * transpose(B)
@test mul!(C, A, adjoint(B)) == A * transpose(B)
@test mul!(C, adjoint(A), B) == transpose(A) * B

# Inplace multiply-add
α = rand(Float64)
Expand All @@ -257,15 +260,57 @@ end
end
end

# Regression test: in-place products of square Float64 matrices (and complex-times-real
# products) must not allocate. Every call form is executed once before it is measured, so
# that one-time costs (e.g. first-call compilation) cannot be attributed to the measured call.
# NOTE(review): n = 2 and n = 3 presumably exercise the dedicated small-matrix kernels
# (matmul2x2!/matmul3x3!) and n = 6 the general path -- confirm against matmul.jl.
@testset "allocations in BLAS-mul" begin
    for n in (2, 3, 6)
        A = rand(Float64, n, n)
        B = rand(Float64, n, n)
        C = zeros(Float64, n, n)
        # gemm
        for t in (identity, adjoint, transpose)
            At = t(A)
            Bt = t(B)
            mul!(C, At, B)
            @test 0 == @allocations mul!(C, At, B)
            mul!(C, A, Bt)
            @test 0 == @allocations mul!(C, A, Bt)
            mul!(C, At, Bt)
            @test 0 == @allocations mul!(C, At, Bt)
        end
        # syrk/herk (warm up each call form first, mirroring the gemm section above)
        mul!(C, transpose(A), A)
        @test 0 == @allocations mul!(C, transpose(A), A)
        mul!(C, adjoint(A), A)
        @test 0 == @allocations mul!(C, adjoint(A), A)
        mul!(C, A, transpose(A))
        @test 0 == @allocations mul!(C, A, transpose(A))
        mul!(C, A, adjoint(A))
        @test 0 == @allocations mul!(C, A, adjoint(A))
        # complex times real
        Cc = complex(C)
        Ac = complex(A)
        for t in (identity, adjoint, transpose)
            Bt = t(B)
            mul!(Cc, Ac, Bt)  # warm-up: do not measure the first execution
            @test 0 == @allocations mul!(Cc, Ac, Bt)
        end
    end
end

# In-place products where one factor has a non-BLAS element type (Int) and the other a
# BLAS element type (Float64), for both plain matrices and views.
# All entries are real, so `adjoint` and `transpose` coincide; the expected values are
# therefore written with `transpose`. Operands are integer-valued, so exact `==` is valid.
# Each assertion is stated once, using the unqualified `mul!` like the rest of this file.
@testset "mixed Blas-non-Blas matmul" begin
    AA = rand(-10:10, 6, 6)
    BB = ones(Float64, 6, 6)
    CC = zeros(Float64, 6, 6)
    for A in (copy(AA), view(AA, 1:6, 1:6)),
        B in (copy(BB), view(BB, 1:6, 1:6)),
        C in (copy(CC), view(CC, 1:6, 1:6))

        @test mul!(C, A, B) == A * B
        @test mul!(C, transpose(A), transpose(B)) == transpose(A) * transpose(B)
        @test mul!(C, A, adjoint(B)) == A * transpose(B)
        @test mul!(C, adjoint(A), B) == transpose(A) * B
    end
end

# Regression test: in-place products of an Int matrix with a Float64 matrix must not
# allocate. Each call form is executed once before it is measured, so that one-time costs
# (e.g. first-call compilation) cannot be attributed to the measured call.
@testset "allocations in mixed Blas-non-Blas matmul" begin
    for n in (2, 3, 6)
        A = rand(-10:10, n, n)
        B = ones(Float64, n, n)
        C = zeros(Float64, n, n)
        mul!(C, A, B)
        @test 0 == @allocations mul!(C, A, B)
        mul!(C, A, transpose(B))
        @test 0 == @allocations mul!(C, A, transpose(B))
        mul!(C, adjoint(A), B)
        @test 0 == @allocations mul!(C, adjoint(A), B)
    end
end

Expand Down

0 comments on commit 6a1df3d

Please sign in to comment.