Skip to content

Commit 44b6bc1

Browse files
araujoms and stevengj
authored
slight optimization of exp, sqrt, ^ for positive semidefinite output (#1359)
The speedup is small (my benchmarks show 2-5%), but since these functions are so often used I think it's worth it. The same optimization also applies to `cosh`, `acosh`, `acos`, and the positive part of other functions, but I didn't do it because I don't think the speedup is worth the complexity. --------- Co-authored-by: Steven G. Johnson <stevenj@alum.mit.edu>
1 parent 8a1ecfb commit 44b6bc1

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

src/symmetric.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,12 @@ function svdvals!(A::RealHermSymComplexHerm)
833833
return sort!(vals, rev = true)
834834
end
835835

836+
#computes U * Diagonal(abs2.(v)) * U'
837+
function _psd_spectral_product(v, U)
838+
Uv = U * Diagonal(v)
839+
return Uv * Uv' # often faster than generic matmul by calling BLAS.herk
840+
end
841+
836842
# Matrix functions
837843
^(A::SymSymTri{<:Complex}, p::Integer) = sympow(A, p)
838844
^(A::SelfAdjoint, p::Integer) = sympow(A, p)
@@ -848,7 +854,8 @@ function ^(A::SelfAdjoint, p::Real)
848854
isinteger(p) && return integerpow(A, p)
849855
F = eigen(A)
850856
if all(λ -> λ ≥ 0, F.values)
851-
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
857+
rootpower = map(λ -> λ^(p / 2), F.values)
858+
retmat = _psd_spectral_product(rootpower, F.vectors)
852859
return wrappertype(A)(retmat)
853860
else
854861
retmat = (F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors'
@@ -860,7 +867,7 @@ function ^(A::SymSymTri{<:Complex}, p::Real)
860867
return Symmetric(schurpow(A, p))
861868
end
862869

863-
for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
870+
for func in (:cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
864871
@eval begin
865872
function ($func)(A::SelfAdjoint)
866873
F = eigen(A)
@@ -870,6 +877,13 @@ for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh,
870877
end
871878
end
872879

880+
function exp(A::SelfAdjoint)
881+
F = eigen(A)
882+
rootexp = map(λ -> exp(λ / 2), F.values)
883+
retmat = _psd_spectral_product(rootexp, F.vectors)
884+
return wrappertype(A)(retmat)
885+
end
886+
873887
function cis(A::SelfAdjoint)
874888
F = eigen(A)
875889
retmat = F.vectors .* cis.(F.values') * F.vectors'
@@ -929,7 +943,8 @@ function sqrt(A::SelfAdjoint; rtol = eps(real(float(eltype(A)))) * size(A, 1))
929943
F = eigen(A)
930944
λ₀ = -maximum(abs, F.values) * rtol # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
931945
if all(λ -> λ ≥ λ₀, F.values)
932-
retmat = (F.vectors * Diagonal(sqrt.(max.(0, F.values)))) * F.vectors'
946+
rootroot = map(λ -> λ < 0 ? zero(λ) : fourthroot(λ), F.values)
947+
retmat = _psd_spectral_product(rootroot, F.vectors)
933948
return wrappertype(A)(retmat)
934949
else
935950
retmat = (F.vectors * Diagonal(sqrt.(complex.(F.values)))) * F.vectors'

0 commit comments

Comments
 (0)