Skip to content

Commit 5dbf0a3

Browse files
araujoms and jishnub
authored
type stability of matrix functions (#1360)
I've fixed the type stability of matrix functions as far as possible without changing the API as discussed in #1130. Also added tests. Closes #460. --------- Co-authored-by: Jishnu Bhattacharya <jishnub.github@gmail.com>
1 parent 045ee5c commit 5dbf0a3

File tree

5 files changed

+126
-61
lines changed

5 files changed

+126
-61
lines changed

src/dense.jl

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ function schurpow(A::AbstractMatrix, p)
592592
end
593593

594594
# if A has nonpositive real eigenvalues, retmat is a nonprincipal matrix power.
595-
if isreal(retmat)
595+
if eltype(A) <: Real && isreal(retmat)
596596
return real(retmat)
597597
else
598598
return retmat
@@ -602,26 +602,35 @@ function (^)(A::AbstractMatrix{T}, p::Real) where T
602602
checksquare(A)
603603
# Quicker return if A is diagonal
604604
if isdiag(A)
605-
TT = promote_op(^, T, typeof(p))
606-
retmat = copymutable_oftype(A, TT)
607-
for i in diagind(retmat, IndexStyle(retmat))
608-
retmat[i] = retmat[i] ^ p
605+
if T <: Real && any(<(0), diagview(A))
606+
return applydiagonal(x -> complex(x)^p, A)
607+
else
608+
return applydiagonal(x -> x^p, A)
609609
end
610-
return retmat
611610
end
612611

613612
# For integer powers, use power_by_squaring
614613
isinteger(p) && return integerpow(A, p)
615614

616615
# If possible, use diagonalization
617616
if ishermitian(A)
618-
return (Hermitian(A)^p)
617+
return _safe_parent(Hermitian(A)^p)
619618
end
620619

621620
# Otherwise, use Schur decomposition
622621
return schurpow(A, p)
623622
end
624623

624+
function _safe_parent(fA)
625+
parentfA = parent(fA)
626+
if isa(fA, Hermitian) || isa(fA, Symmetric{<:Real})
627+
return copytri_maybe_inplace(parentfA, 'U', true)
628+
elseif isa(fA, Symmetric)
629+
return copytri_maybe_inplace(parentfA, 'U')
630+
else
631+
return fA
632+
end
633+
end
625634
"""
626635
^(A::AbstractMatrix, p::Number)
627636
@@ -918,15 +927,13 @@ julia> log(A)
918927
function log(A::AbstractMatrix)
919928
# If possible, use diagonalization
920929
if isdiag(A) && eltype(A) <: Union{Real,Complex}
921-
if eltype(A) <: Real && all(>=(0), diagview(A))
922-
return applydiagonal(log, A)
930+
if eltype(A) <: Real && any(<(0), diagview(A))
931+
return applydiagonal(log ∘ complex, A)
923932
else
924-
return applydiagonal(log∘complex, A)
933+
return applydiagonal(log, A)
925934
end
926935
elseif ishermitian(A)
927-
logHermA = log(Hermitian(A))
928-
PH = parent(logHermA)
929-
return ishermitian(logHermA) ? copytri_maybe_inplace(PH, 'U', true) : PH
936+
return _safe_parent(log(Hermitian(A)))
930937
elseif istriu(A)
931938
return triu!(parent(log(UpperTriangular(A))))
932939
elseif isreal(A)
@@ -1004,13 +1011,14 @@ sqrt(::AbstractMatrix)
10041011
function sqrt(A::AbstractMatrix{T}) where {T<:Union{Real,Complex}}
10051012
if checksquare(A) == 0
10061013
return copy(float(A))
1007-
elseif isdiag(A) && (T <: Complex || all(x -> x ≥ zero(x), diagview(A)))
1008-
# Real Diagonal sqrt requires each diagonal element to be positive
1009-
return applydiagonal(sqrt, A)
1014+
elseif isdiag(A)
1015+
if T <: Real && any(<(0), diagview(A))
1016+
return applydiagonal(sqrt ∘ complex, A)
1017+
else
1018+
return applydiagonal(sqrt, A)
1019+
end
10101020
elseif ishermitian(A)
1011-
sqrtHermA = sqrt(Hermitian(A))
1012-
PS = parent(sqrtHermA)
1013-
return ishermitian(sqrtHermA) ? copytri_maybe_inplace(PS, 'U', true) : PS
1021+
return _safe_parent(sqrt(Hermitian(A)))
10141022
elseif istriu(A)
10151023
return triu!(parent(sqrt(UpperTriangular(A))))
10161024
elseif isreal(A)
@@ -1044,7 +1052,7 @@ sqrt(A::TransposeAbsMat) = transpose(sqrt(parent(A)))
10441052
Computes the real-valued cube root of a real-valued matrix `A`. If `T = cbrt(A)`, then
10451053
we have `T*T*T ≈ A`, see example given below.
10461054
1047-
If `A` is symmetric, i.e., of type `HermOrSym{<:Real}`, then ([`eigen`](@ref)) is used to
1055+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
10481056
find the cube root. Otherwise, a specialized version of the p-th root algorithm [^S03] is
10491057
utilized, which exploits the real-valued Schur decomposition ([`schur`](@ref))
10501058
to compute the cube root.
@@ -1077,7 +1085,7 @@ function cbrt(A::AbstractMatrix{<:Real})
10771085
elseif isdiag(A)
10781086
return applydiagonal(cbrt, A)
10791087
elseif issymmetric(A)
1080-
return cbrt(Symmetric(A, :U))
1088+
return copytri_maybe_inplace(parent(cbrt(Symmetric(A))), 'U')
10811089
else
10821090
S = schur(A)
10831091
return S.Z * _cbrt_quasi_triu!(S.T) * S.Z'
@@ -1118,7 +1126,7 @@ end
11181126
11191127
Compute the matrix cosine of a square matrix `A`.
11201128
1121-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1129+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
11221130
compute the cosine. Otherwise, the cosine is determined by calling [`exp`](@ref).
11231131
11241132
# Examples
@@ -1160,7 +1168,7 @@ end
11601168
11611169
Compute the matrix sine of a square matrix `A`.
11621170
1163-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1171+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
11641172
compute the sine. Otherwise, the sine is determined by calling [`exp`](@ref).
11651173
11661174
# Examples
@@ -1265,7 +1273,7 @@ end
12651273
12661274
Compute the matrix tangent of a square matrix `A`.
12671275
1268-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1276+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
12691277
compute the tangent. Otherwise, the tangent is determined by calling [`exp`](@ref).
12701278
12711279
# Examples
@@ -1357,7 +1365,7 @@ _subadd!!(X, Y) = X - Y, X + Y
13571365
13581366
Compute the inverse matrix cosine of a square matrix `A`.
13591367
1360-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1368+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
13611369
compute the inverse cosine. Otherwise, the inverse cosine is determined by using
13621370
[`log`](@ref) and [`sqrt`](@ref). For the theory and logarithmic formulas used to compute
13631371
this function, see [^AH16_1].
@@ -1376,9 +1384,7 @@ function acos(A::AbstractMatrix)
13761384
if isdiag(A)
13771385
return applydiagonal(acos, A)
13781386
elseif ishermitian(A)
1379-
acosHermA = acos(Hermitian(A))
1380-
P = parent(acosHermA)
1381-
return isa(acosHermA, Hermitian) ? copytri_maybe_inplace(P, 'U', true) : P
1387+
return _safe_parent(acos(Hermitian(A)))
13821388
end
13831389
SchurF = Schur{Complex}(schur(A))
13841390
U = UpperTriangular(SchurF.T)
@@ -1391,7 +1397,7 @@ end
13911397
13921398
Compute the inverse matrix sine of a square matrix `A`.
13931399
1394-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1400+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
13951401
compute the inverse sine. Otherwise, the inverse sine is determined by using [`log`](@ref)
13961402
and [`sqrt`](@ref). For the theory and logarithmic formulas used to compute this function,
13971403
see [^AH16_2].
@@ -1425,7 +1431,7 @@ end
14251431
14261432
Compute the inverse matrix tangent of a square matrix `A`.
14271433
1428-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1434+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
14291435
compute the inverse tangent. Otherwise, the inverse tangent is determined by using
14301436
[`log`](@ref). For the theory and logarithmic formulas used to compute this function, see
14311437
[^AH16_3].
@@ -1436,8 +1442,8 @@ compute the inverse tangent. Otherwise, the inverse tangent is determined by usi
14361442
```julia-repl
14371443
julia> atan(tan([0.5 0.1; -0.2 0.3]))
14381444
2×2 Matrix{ComplexF64}:
1439-
0.5+1.38778e-17im 0.1-2.77556e-17im
1440-
-0.2+6.93889e-17im 0.3-4.16334e-17im
1445+
0.5 0.1
1446+
-0.2 0.3
14411447
```
14421448
"""
14431449
function atan(A::AbstractMatrix)
@@ -1450,7 +1456,12 @@ function atan(A::AbstractMatrix)
14501456
SchurF = Schur{Complex}(schur(A))
14511457
U = im * UpperTriangular(SchurF.T)
14521458
R = triu!(parent(log((I + U) / (I - U)) / 2im))
1453-
return SchurF.Z * R * SchurF.Z'
1459+
retmat = SchurF.Z * R * SchurF.Z'
1460+
if eltype(A) <: Real
1461+
return real(retmat)
1462+
else
1463+
return retmat
1464+
end
14541465
end
14551466

14561467
"""
@@ -1465,9 +1476,7 @@ function acosh(A::AbstractMatrix)
14651476
if isdiag(A)
14661477
return applydiagonal(acosh, A)
14671478
elseif ishermitian(A)
1468-
acoshHermA = acosh(Hermitian(A))
1469-
P = parent(acoshHermA)
1470-
return isa(acoshHermA, Hermitian) ? copytri_maybe_inplace(P, 'U', true) : P
1479+
return _safe_parent(acosh(Hermitian(A)))
14711480
end
14721481
SchurF = Schur{Complex}(schur(A))
14731482
U = UpperTriangular(SchurF.T)
@@ -1493,7 +1502,12 @@ function asinh(A::AbstractMatrix)
14931502
SchurF = Schur{Complex}(schur(A))
14941503
U = UpperTriangular(SchurF.T)
14951504
R = triu!(parent(log(U + sqrt(I + U^2))))
1496-
return SchurF.Z * R * SchurF.Z'
1505+
retmat = SchurF.Z * R * SchurF.Z'
1506+
if eltype(A) <: Real
1507+
return real(retmat)
1508+
else
1509+
return retmat
1510+
end
14971511
end
14981512

14991513
"""
@@ -1508,8 +1522,7 @@ function atanh(A::AbstractMatrix)
15081522
if isdiag(A)
15091523
return applydiagonal(atanh, A)
15101524
elseif ishermitian(A)
1511-
P = parent(atanh(Hermitian(A)))
1512-
return copytri_maybe_inplace(P, 'U', true)
1525+
return _safe_parent(atanh(Hermitian(A)))
15131526
end
15141527
SchurF = Schur{Complex}(schur(A))
15151528
U = UpperTriangular(SchurF.T)

src/symmetric.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ function ^(A::SymSymTri{<:Complex}, p::Real)
869869
return Symmetric(schurpow(A, p))
870870
end
871871

872-
for func in (:cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
872+
for func in (:cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :cbrt)
873873
@eval begin
874874
function ($func)(A::SelfAdjoint)
875875
F = eigen(A)
@@ -892,7 +892,7 @@ function cis(A::SelfAdjoint)
892892
return nonhermitianwrappertype(A)(retmat)
893893
end
894894

895-
for func in (:acos, :asin)
895+
for func in (:acos, :asin, :atanh)
896896
@eval begin
897897
function ($func)(A::SelfAdjoint)
898898
F = eigen(A)

src/triangular.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1987,7 +1987,7 @@ function powm!(A0::UpperTriangular, p::Real)
19871987
A[i, i] = -A[i, i]
19881988
end
19891989
# Compute the Padé approximant
1990-
c = 0.5 * (p - m) / (2 * m - 1)
1990+
c = (p - m) / (4 * m - 2)
19911991
triu!(A)
19921992
S = c * A
19931993
Stmp = similar(S)

test/dense.jl

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -639,8 +639,8 @@ end
639639
sinA1 = convert(Matrix{elty}, [0.2865568596627417 -1.107751980582015 -0.13772915374386513;
640640
-0.6227405671629401 0.2176922827908092 -0.5538759902910078;
641641
-0.6227405671629398 -0.6916051440348725 0.3554214365346742])
642-
@test @inferred(cos(A1)) ≈ cosA1
643-
@test @inferred(sin(A1)) ≈ sinA1
642+
@test cos(A1) ≈ cosA1
643+
@test sin(A1) ≈ sinA1
644644

645645
cosA2 = convert(Matrix{elty}, [-0.6331745163802187 0.12878366262380136 -0.17304181968301532;
646646
0.12878366262380136 -0.5596234510748788 0.5210483146041339;
@@ -663,22 +663,22 @@ end
663663

664664
# Identities
665665
for (i, A) in enumerate((A1, A2, A3, A4, A5))
666-
@test @inferred(sincos(A)) == (sin(A), cos(A))
666+
@test sincos(A) == (sin(A), cos(A))
667667
@test cos(A)^2 + sin(A)^2 ≈ Matrix(I, size(A))
668668
@test cos(A) ≈ cos(-A)
669669
@test sin(A) ≈ -sin(-A)
670-
@test @inferred(tan(A)) ≈ sin(A) / cos(A)
670+
@test tan(A) ≈ sin(A) / cos(A)
671671

672672
@test cos(A) ≈ real(exp(im*A))
673673
@test sin(A) ≈ imag(exp(im*A))
674674
@test cos(A) ≈ real(cis(A))
675675
@test sin(A) ≈ imag(cis(A))
676-
@test @inferred(cis(A)) ≈ cos(A) + im * sin(A)
676+
@test cis(A) ≈ cos(A) + im * sin(A)
677677

678-
@test @inferred(cosh(A)) ≈ 0.5 * (exp(A) + exp(-A))
679-
@test @inferred(sinh(A)) ≈ 0.5 * (exp(A) - exp(-A))
680-
@test @inferred(cosh(A)) ≈ cosh(-A)
681-
@test @inferred(sinh(A)) ≈ -sinh(-A)
678+
@test cosh(A) ≈ 0.5 * (exp(A) + exp(-A))
679+
@test sinh(A) ≈ 0.5 * (exp(A) - exp(-A))
680+
@test cosh(A) ≈ cosh(-A)
681+
@test sinh(A) ≈ -sinh(-A)
682682

683683
# Some of the following identities fail for A3, A4 because the matrices are singular
684684
if i in (1, 2, 5)
@@ -687,7 +687,7 @@ end
687687
@test @inferred(cot(A)) ≈ inv(tan(A))
688688
@test @inferred(sech(A)) ≈ inv(cosh(A))
689689
@test @inferred(csch(A)) ≈ inv(sinh(A))
690-
@test @inferred(coth(A)) ≈ inv(@inferred tanh(A))
690+
@test @inferred(coth(A)) ≈ inv(tanh(A))
691691
end
692692
# The following identities fail for A1, A2 due to rounding errors;
693693
# probably needs better algorithm for the general case
@@ -904,11 +904,6 @@ end
904904
end
905905
end
906906

907-
@testset "matrix logarithm is type-inferable" for elty in (Float32,Float64,ComplexF32,ComplexF64)
908-
A1 = randn(elty, 4, 4)
909-
@inferred Union{Matrix{elty},Matrix{complex(elty)}} log(A1)
910-
end
911-
912907
@testset "Additional matrix square root tests" for elty in (Float64, ComplexF64)
913908
A11 = convert(Matrix{elty}, [3 2; -5 -3])
914909
@test sqrt(A11)^2 ≈ A11
@@ -1073,7 +1068,7 @@ end
10731068
@test lyap(1.0+2.0im, 3.0+4.0im) == -1.5 - 2.0im
10741069
end
10751070

1076-
@testset "$elty Matrix to real power" for elty in (Float64, ComplexF64)
1071+
@testset "$elty Matrix to real power" for elty in (Float32, Float64, ComplexF32, ComplexF64)
10771072
# Tests proposed at Higham, Deadman: Testing Matrix Function Algorithms Using Identities, March 2014
10781073
#Aa : only positive real eigenvalues
10791074
Aa = convert(Matrix{elty}, [5 4 2 1; 0 1 -1 -1; -1 -1 3 0; 1 1 -1 2])
@@ -1099,7 +1094,13 @@ end
10991094
ADi += [im 0; 0 im]
11001095
end
11011096

1102-
for A in (Aa, Ab, Ac, Ad, Ah, ADi)
1097+
#ADin : negative Diagonal Matrix
1098+
ADin = convert(Matrix{elty}, [-3 0; 0 3])
1099+
if elty <: LinearAlgebra.BlasComplex
1100+
ADin += [im 0; 0 im]
1101+
end
1102+
1103+
for A in (Aa, Ab, Ac, Ad, Ah, ADi, ADin)
11031104
@test A^(1/2) ≈ sqrt(A)
11041105
@test A^(-1/2) ≈ inv(sqrt(A))
11051106
@test A^(3/4) ≈ sqrt(A) * sqrt(sqrt(A))
@@ -1112,7 +1113,7 @@ end
11121113
end
11131114

11141115
Tschurpow = Union{Matrix{real(elty)}, Matrix{complex(elty)}}
1115-
@test (@inferred Tschurpow LinearAlgebra.schurpow(Aa, 2.0)) ≈ Aa^2
1116+
@test (@inferred Tschurpow LinearAlgebra.schurpow(Aa, real(elty)(2.0))) ≈ Aa^2
11161117
end
11171118

11181119
@testset "BigFloat triangular real power" begin

0 commit comments

Comments
 (0)