diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index b5d79f2aa19d9..bd0566a11b3f2 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -2323,7 +2323,7 @@ sqrt(A::UnitLowerTriangular) = copy(transpose(sqrt(copy(transpose(A))))) # Auxiliary functions for matrix square root # square root of upper triangular or real upper quasitriangular matrix -function sqrt_quasitriu(A0) +function sqrt_quasitriu(A0; blockwidth = eltype(A0) <: Complex ? 512 : 256) n = checksquare(A0) T = eltype(A0) Tr = typeof(sqrt(real(zero(T)))) @@ -2350,7 +2350,7 @@ function sqrt_quasitriu(A0) A = A0 R = zeros(Tc, n, n) end - _sqrt_quasitriu!(R, A) + _sqrt_quasitriu!(R, A; blockwidth=blockwidth, n=n) Rc = eltype(A0) <: Real ? R : complex(R) if A0 isa UpperTriangular return UpperTriangular(Rc) @@ -2361,7 +2361,32 @@ function sqrt_quasitriu(A0) end end -function _sqrt_quasitriu!(R, A) +# in-place recursive sqrt of upper quasi-triangular matrix A from +# Deadman E., Higham N.J., Ralha R. (2013) Blocked Schur Algorithms for Computing the Matrix +# Square Root. Applied Parallel and Scientific Computing. PARA 2012. Lecture Notes in +# Computer Science, vol 7782. https://doi.org/10.1007/978-3-642-36803-5_12 +function _sqrt_quasitriu!(R, A; blockwidth=64, n=checksquare(A)) + if n ≤ blockwidth || !(eltype(R) <: BlasFloat) # base case, perform "point" algorithm + _sqrt_quasitriu_block!(R, A) + else # compute blockwise recursion + split = div(n, 2) + iszero(A[split+1, split]) || (split += 1) # don't split 2x2 diagonal block + r1 = 1:split + r2 = (split + 1):n + n1, n2 = split, n - split + A11, A12, A22 = @views A[r1,r1], A[r1,r2], A[r2,r2] + R11, R12, R22 = @views R[r1,r1], R[r1,r2], R[r2,r2] + # solve diagonal blocks recursively + _sqrt_quasitriu!(R11, A11; blockwidth=blockwidth, n=n1) + _sqrt_quasitriu!(R22, A22; blockwidth=blockwidth, n=n2) + # solve off-diagonal block + R12 .= .- A12 + _sylvester_quasitriu!(R11, R22, R12; blockwidth=blockwidth, nA=n1, nB=n2, raise=false) + end + return R +end + +function _sqrt_quasitriu_block!(R, A) _sqrt_quasitriu_diag_block!(R, A) _sqrt_quasitriu_offdiag_block!(R, A) return R @@ -2514,6 +2539,83 @@ Base.@propagate_inbounds function _sqrt_quasitriu_offdiag_block_2x2!(R, A, i, j) return R end +# solve Sylvester's equation AX + XB = -C using blockwise recursion until the dimension of +# A and B are no greater than blockwidth, based on Algorithm 1 from +# Jonsson I, Kågström B. Recursive blocked algorithms for solving triangular systems— +# Part I: one-sided and coupled Sylvester-type matrix equations. (2002) ACM Trans Math Softw. +# 28(4), https://doi.org/10.1145/592843.592845. +# specify raise=false to avoid breaking the recursion if a LAPACKException is thrown when +# computing one of the blocks. +function _sylvester_quasitriu!(A, B, C; blockwidth=64, nA=checksquare(A), nB=checksquare(B), raise=true) + if 1 ≤ nA ≤ blockwidth && 1 ≤ nB ≤ blockwidth + _sylvester_quasitriu_base!(A, B, C; raise=raise) + elseif nA ≥ 2nB ≥ 2 + _sylvester_quasitriu_split1!(A, B, C; blockwidth=blockwidth, nA=nA, nB=nB, raise=raise) + elseif nB ≥ 2nA ≥ 2 + _sylvester_quasitriu_split2!(A, B, C; blockwidth=blockwidth, nA=nA, nB=nB, raise=raise) + else + _sylvester_quasitriu_splitall!(A, B, C; blockwidth=blockwidth, nA=nA, nB=nB, raise=raise) + end + return C +end +function _sylvester_quasitriu_base!(A, B, C; raise=true) + try + _, scale = LAPACK.trsyl!('N', 'N', A, B, C) + rmul!(C, -inv(scale)) + catch e + if !(e isa LAPACKException) || raise + throw(e) + end + end + return C +end +function _sylvester_quasitriu_split1!(A, B, C; nA=checksquare(A), kwargs...) + iA = div(nA, 2) + iszero(A[iA + 1, iA]) || (iA += 1) # don't split 2x2 diagonal block + rA1, rA2 = 1:iA, (iA + 1):nA + nA1, nA2 = iA, nA-iA + A11, A12, A22 = @views A[rA1,rA1], A[rA1,rA2], A[rA2,rA2] + C1, C2 = @views C[rA1,:], C[rA2,:] + _sylvester_quasitriu!(A22, B, C2; nA=nA2, kwargs...) + mul!(C1, A12, C2, true, true) + _sylvester_quasitriu!(A11, B, C1; nA=nA1, kwargs...) + return C +end +function _sylvester_quasitriu_split2!(A, B, C; nB=checksquare(B), kwargs...) + iB = div(nB, 2) + iszero(B[iB + 1, iB]) || (iB += 1) # don't split 2x2 diagonal block + rB1, rB2 = 1:iB, (iB + 1):nB + nB1, nB2 = iB, nB-iB + B11, B12, B22 = @views B[rB1,rB1], B[rB1,rB2], B[rB2,rB2] + C1, C2 = @views C[:,rB1], C[:,rB2] + _sylvester_quasitriu!(A, B11, C1; nB=nB1, kwargs...) + mul!(C2, C1, B12, true, true) + _sylvester_quasitriu!(A, B22, C2; nB=nB2, kwargs...) + return C +end +function _sylvester_quasitriu_splitall!(A, B, C; nA=checksquare(A), nB=checksquare(B), kwargs...) + iA = div(nA, 2) + iszero(A[iA + 1, iA]) || (iA += 1) # don't split 2x2 diagonal block + iB = div(nB, 2) + iszero(B[iB + 1, iB]) || (iB += 1) # don't split 2x2 diagonal block + rA1, rA2 = 1:iA, (iA + 1):nA + nA1, nA2 = iA, nA-iA + rB1, rB2 = 1:iB, (iB + 1):nB + nB1, nB2 = iB, nB-iB + A11, A12, A22 = @views A[rA1,rA1], A[rA1,rA2], A[rA2,rA2] + B11, B12, B22 = @views B[rB1,rB1], B[rB1,rB2], B[rB2,rB2] + C11, C21, C12, C22 = @views C[rA1,rB1], C[rA2,rB1], C[rA1,rB2], C[rA2,rB2] + _sylvester_quasitriu!(A22, B11, C21; nA=nA2, nB=nB1, kwargs...) + mul!(C11, A12, C21, true, true) + _sylvester_quasitriu!(A11, B11, C11; nA=nA1, nB=nB1, kwargs...) + mul!(C22, C21, B12, true, true) + _sylvester_quasitriu!(A22, B22, C22; nA=nA2, nB=nB2, kwargs...) + mul!(C12, A12, C22, true, true) + mul!(C12, C11, B12, true, true) + _sylvester_quasitriu!(A11, B22, C12; nA=nA1, nB=nB2, kwargs...) + return C +end + # End of auxiliary functions for matrix square root # Generic eigensystems diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index 9681b40a22075..6950d7a956b87 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -513,6 +513,41 @@ Atu = UnitUpperTriangular([1 1 2; 0 1 2; 0 0 1]) @test typeof(sqrt(Atu)[1,1]) <: Real @test typeof(sqrt(complex(Atu))[1,1]) <: Complex +@testset "matrix square root quasi-triangular blockwise" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + A = schur(rand(T, 100, 100)^2).T + @test LinearAlgebra.sqrt_quasitriu(A; blockwidth=16)^2 ≈ A + end + n = 256 + A = rand(ComplexF64, n, n) + U = schur(A).T + Ubig = Complex{BigFloat}.(U) + @test LinearAlgebra.sqrt_quasitriu(U; blockwidth=64) ≈ LinearAlgebra.sqrt_quasitriu(Ubig; blockwidth=64) +end + +@testset "sylvester quasi-triangular blockwise" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64), m in (15, 40), n in (15, 45) + A = schur(rand(T, m, m)).T + B = schur(rand(T, n, n)).T + C = randn(T, m, n) + Ccopy = copy(C) + X = LinearAlgebra._sylvester_quasitriu!(A, B, C; blockwidth=16) + @test X === C + @test A * X + X * B ≈ -Ccopy + + @testset "test raise=false does not break recursion" begin + Az = zero(A) + Bz = zero(B) + C2 = copy(Ccopy) + @test_throws LAPACKException LinearAlgebra._sylvester_quasitriu!(Az, Bz, C2; blockwidth=16) + m == n || @test any(C2 .== Ccopy) # recursion broken + C3 = copy(Ccopy) + X3 = LinearAlgebra._sylvester_quasitriu!(Az, Bz, C3; blockwidth=16, raise=false) + @test !any(X3 .== Ccopy) # recursion not broken + end + end +end + @testset "check matrix logarithm type-inferrable" for elty in (Float32,Float64,ComplexF32,ComplexF64) A = UpperTriangular(exp(triu(randn(elty, n, n)))) @inferred Union{typeof(A),typeof(complex(A))} log(A)