Skip to content

Commit e37e98a

Browse files
authored
Transpose and Adjoint support for exp, log and sqrt functions (#39373)
1 parent ca6df85 commit e37e98a

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,8 @@ julia> exp(A)
559559
"""
560560
exp(A::StridedMatrix{<:BlasFloat}) = exp!(copy(A))
561561
exp(A::StridedMatrix{<:Union{Integer,Complex{<:Integer}}}) = exp!(float.(A))
562+
exp(A::Adjoint{<:Any,<:AbstractMatrix}) = adjoint(exp(parent(A)))
563+
exp(A::Transpose{<:Any,<:AbstractMatrix}) = transpose(exp(parent(A)))
562564

563565
"""
564566
cis(A::AbstractMatrix)
@@ -762,6 +764,9 @@ function log(A::StridedMatrix)
762764
end
763765
end
764766

767+
log(A::Adjoint{<:Any,<:AbstractMatrix}) = adjoint(log(parent(A)))
768+
log(A::Transpose{<:Any,<:AbstractMatrix}) = transpose(log(parent(A)))
769+
765770
"""
766771
sqrt(A::AbstractMatrix)
767772
@@ -837,6 +842,9 @@ function sqrt(A::StridedMatrix{T}) where {T<:Union{Real,Complex}}
837842
end
838843
end
839844

845+
sqrt(A::Adjoint{<:Any,<:AbstractMatrix}) = adjoint(sqrt(parent(A)))
846+
sqrt(A::Transpose{<:Any,<:AbstractMatrix}) = transpose(sqrt(parent(A)))
847+
840848
function inv(A::StridedMatrix{T}) where T
841849
checksquare(A)
842850
S = typeof((one(T)*zero(T) + one(T)*zero(T))/one(T))

stdlib/LinearAlgebra/test/dense.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,13 @@ end
145145
@testset "Matrix square root" begin
146146
asq = sqrt(a)
147147
@test asq*asq a
148+
@test sqrt(transpose(a))*sqrt(transpose(a)) transpose(a)
149+
@test sqrt(adjoint(a))*sqrt(adjoint(a)) adjoint(a)
148150
asym = a + a' # symmetric indefinite
149151
asymsq = sqrt(asym)
150152
@test asymsq*asymsq asym
153+
@test sqrt(transpose(asym))*sqrt(transpose(asym)) transpose(asym)
154+
@test sqrt(adjoint(asym))*sqrt(adjoint(asym)) adjoint(asym)
151155
if eltype(a) <: Real # real square root
152156
apos = a * a
153157
@test sqrt(apos)^2 apos
@@ -447,6 +451,11 @@ end
447451
183.765138646367 183.765138646366 163.679601723179;
448452
71.797032399996 91.8825693231832 111.968106246371]')
449453
@test exp(A1) eA1
454+
@test exp(adjoint(A1)) adjoint(eA1)
455+
@test exp(transpose(A1)) transpose(eA1)
456+
for f in (sin, cos, sinh, cosh, tanh, tan)
457+
@test f(adjoint(A1)) f(copy(adjoint(A1)))
458+
end
450459

451460
A2 = convert(Matrix{elty},
452461
[29.87942128909879 0.7815750847907159 -2.289519314033932;
@@ -457,20 +466,28 @@ end
457466
-18231880972009252.0 60605228702221920.0 101291842930249760.0;
458467
-30475770808580480.0 101291842930249728.0 169294411240851968.0])
459468
@test exp(A2) eA2
469+
@test exp(adjoint(A2)) adjoint(eA2)
470+
@test exp(transpose(A2)) transpose(eA2)
460471

461472
A3 = convert(Matrix{elty}, [-131 19 18;-390 56 54;-387 57 52])
462473
eA3 = convert(Matrix{elty}, [-1.50964415879218 -5.6325707998812 -4.934938326092;
463474
0.367879439109187 1.47151775849686 1.10363831732856;
464475
0.135335281175235 0.406005843524598 0.541341126763207]')
465476
@test exp(A3) eA3
477+
@test exp(adjoint(A3)) adjoint(eA3)
478+
@test exp(transpose(A3)) transpose(eA3)
466479

467480
A4 = convert(Matrix{elty}, [0.25 0.25; 0 0])
468481
eA4 = convert(Matrix{elty}, [1.2840254166877416 0.2840254166877415; 0 1])
469482
@test exp(A4) eA4
483+
@test exp(adjoint(A4)) adjoint(eA4)
484+
@test exp(transpose(A4)) transpose(eA4)
470485

471486
A5 = convert(Matrix{elty}, [0 0.02; 0 0])
472487
eA5 = convert(Matrix{elty}, [1 0.02; 0 1])
473488
@test exp(A5) eA5
489+
@test exp(adjoint(A5)) adjoint(eA5)
490+
@test exp(transpose(A5)) transpose(eA5)
474491

475492
# Hessenberg
476493
@test hessenberg(A1).H convert(Matrix{elty},
@@ -496,15 +513,23 @@ end
496513
1/4 1/5 1/6 1/7;
497514
1/5 1/6 1/7 1/8])
498515
@test exp(log(A4)) A4
516+
@test exp(log(transpose(A4))) transpose(A4)
517+
@test exp(log(adjoint(A4))) adjoint(A4)
499518

500519
A5 = convert(Matrix{elty}, [1 1 0 1; 0 1 1 0; 0 0 1 1; 1 0 0 1])
501520
@test exp(log(A5)) A5
521+
@test exp(log(transpose(A5))) transpose(A5)
522+
@test exp(log(adjoint(A5))) adjoint(A5)
502523

503524
A6 = convert(Matrix{elty}, [-5 2 0 0 ; 1/2 -7 3 0; 0 1/3 -9 4; 0 0 1/4 -11])
504525
@test exp(log(A6)) A6
526+
@test exp(log(transpose(A6))) transpose(A6)
527+
@test exp(log(adjoint(A6))) adjoint(A6)
505528

506529
A7 = convert(Matrix{elty}, [1 0 0 1e-8; 0 1 0 0; 0 0 1 0; 0 0 0 1])
507530
@test exp(log(A7)) A7
531+
@test exp(log(transpose(A7))) transpose(A7)
532+
@test exp(log(adjoint(A7))) adjoint(A7)
508533
end
509534

510535
@testset "Integer promotion tests" begin

0 commit comments

Comments
 (0)