Skip to content

Commit 6e5ea12

Browse files
authored
sqrt, cbrt and log for dense diagonal matrices (#1156)
This PR improves performance by only applying the functions to the diagonal elements: ```julia julia> A = diagm(0=>ones(100)); julia> @Btime log($A); 364.163 μs (22 allocations: 401.62 KiB) # master 13.528 μs (7 allocations: 80.02 KiB) # this PR ``` Similar improvements for `sqrt` and `cbrt` as well.
1 parent 959d985 commit 6e5ea12

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

src/dense.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,9 @@ julia> log(A)
890890
"""
891891
function log(A::AbstractMatrix)
892892
# If possible, use diagonalization
893-
if ishermitian(A)
893+
if isdiag(A)
894+
return applydiagonal(log, A)
895+
elseif ishermitian(A)
894896
logHermA = log(Hermitian(A))
895897
return ishermitian(logHermA) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
896898
elseif istriu(A)
@@ -969,7 +971,9 @@ sqrt(::AbstractMatrix)
969971

970972
function sqrt(A::AbstractMatrix{T}) where {T<:Union{Real,Complex}}
971973
if checksquare(A) == 0
972-
return copy(A)
974+
return copy(float(A))
975+
elseif isdiag(A)
976+
return applydiagonal(sqrt, A)
973977
elseif ishermitian(A)
974978
sqrtHermA = sqrt(Hermitian(A))
975979
return ishermitian(sqrtHermA) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
@@ -1035,7 +1039,9 @@ true
10351039
"""
10361040
function cbrt(A::AbstractMatrix{<:Real})
10371041
if checksquare(A) == 0
1038-
return copy(A)
1042+
return copy(float(A))
1043+
elseif isdiag(A)
1044+
return applydiagonal(cbrt, A)
10391045
elseif issymmetric(A)
10401046
return cbrt(Symmetric(A, :U))
10411047
else

test/dense.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,7 @@ end
817817

818818
A13 = convert(Matrix{elty}, [2 0; 0 2])
819819
@test typeof(log(A13)) == Array{elty, 2}
820+
@test exp(log(A13)) log(exp(A13)) A13
820821

821822
T = elty == Float64 ? Symmetric : Hermitian
822823
@test typeof(log(T(A13))) == T{elty, Array{elty, 2}}
@@ -968,6 +969,10 @@ end
968969
@test typeof(sqrt(A8)) == Matrix{elty}
969970
end
970971
end
972+
@testset "sqrt for diagonal" begin
973+
A = diagm(0 => [1, 2, 3])
974+
@test sqrt(A)^2 A
975+
end
971976

972977
@testset "issue #40141" begin
973978
x = [-1 -eps() 0 0; eps() -1 0 0; 0 0 -1 -eps(); 0 0 eps() -1]
@@ -1280,6 +1285,7 @@ end
12801285
T = cbrt(Symmetric(S,:U))
12811286
@test T*T*T S
12821287
@test eltype(S) == eltype(T)
1288+
@test cbrt(Array(Symmetric(S,:U))) == T
12831289
# Real valued symmetric
12841290
S = (A -> (A+A')/2)(randn(N,N))
12851291
T = cbrt(Symmetric(S,:L))
@@ -1300,6 +1306,16 @@ end
13001306
T = cbrt(A)
13011307
@test T*T*T A
13021308
@test eltype(A) == eltype(T)
1309+
@testset "diagonal" begin
1310+
A = diagm(0 => [1, 2, 3])
1311+
@test cbrt(A)^3 A
1312+
end
1313+
@testset "empty" begin
1314+
A = Matrix{Float64}(undef, 0, 0)
1315+
@test cbrt(A) == A
1316+
A = Matrix{Int}(undef, 0, 0)
1317+
@test cbrt(A) isa Matrix{Float64}
1318+
end
13031319
end
13041320

13051321
@testset "tr" begin

0 commit comments

Comments
 (0)