Skip to content

Commit bd5f71f

Browse files
authored
Out-of-place copytri in dense trig functions (#1224)
With this, the trig functions work for `SMatrix`es that don't support `setindex!`. ```julia julia> S = SMatrix{4,4}(Symmetric(rand(4,4))) 4×4 SMatrix{4, 4, Float64, 16} with indices SOneTo(4)×SOneTo(4): 0.470973 0.212605 0.0227701 0.744485 0.212605 0.9227 0.022621 0.0139553 0.0227701 0.022621 0.0568546 0.767089 0.744485 0.0139553 0.767089 0.186663 julia> sin(S) 4×4 SMatrix{4, 4, Float64, 16} with indices SOneTo(4)×SOneTo(4): 0.343589 0.140232 -0.044186 0.566605 0.140232 0.781448 -0.00290253 -0.0305854 -0.044186 -0.00290253 0.027389 0.623148 0.566605 -0.0305854 0.623148 0.0739855 ``` There are some type-inference issues at present, as `applydiagonal` returns a mutable array, but these may be addressed separately.
1 parent eb7daee commit bd5f71f

File tree

2 files changed

+77
-20
lines changed

2 files changed

+77
-20
lines changed

src/dense.jl

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,22 @@ cis(A::AbstractMatrix{<:Base.HWNumber}) = exp_maybe_inplace(float.(im .* A))
682682
exp_maybe_inplace(A::StridedMatrix{<:Union{ComplexF32, ComplexF64}}) = exp!(A)
683683
exp_maybe_inplace(A) = exp(A)
684684

685+
function copytri_maybe_inplace(A::StridedMatrix, uplo, conjugate::Bool=false, diag::Bool=false)
686+
copytri!(A, uplo, conjugate, diag)
687+
end
688+
function copytri_maybe_inplace(A, uplo, conjugate::Bool=false, diag::Bool=false)
689+
k = Int(diag)
690+
if uplo == 'U'
691+
B = triu(A, 1-k)
692+
triu(A, k) + (conjugate ? copy(adjoint(B)) : copy(transpose(B)))
693+
elseif uplo == 'L'
694+
B = tril(A, k-1)
695+
tril(A, -k) + (conjugate ? copy(adjoint(B)) : copy(transpose(B)))
696+
else
697+
throw(ArgumentError(lazy"uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
698+
end
699+
end
700+
685701
"""
686702
^(b::Number, A::AbstractMatrix)
687703
@@ -894,7 +910,8 @@ function log(A::AbstractMatrix)
894910
return applydiagonal(log, A)
895911
elseif ishermitian(A)
896912
logHermA = log(Hermitian(A))
897-
return ishermitian(logHermA) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
913+
PH = parent(logHermA)
914+
return ishermitian(logHermA) ? copytri_maybe_inplace(PH, 'U', true) : PH
898915
elseif istriu(A)
899916
return triu!(parent(log(UpperTriangular(A))))
900917
elseif isreal(A)
@@ -977,7 +994,8 @@ function sqrt(A::AbstractMatrix{T}) where {T<:Union{Real,Complex}}
977994
return applydiagonal(sqrt, A)
978995
elseif ishermitian(A)
979996
sqrtHermA = sqrt(Hermitian(A))
980-
return ishermitian(sqrtHermA) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
997+
PS = parent(sqrtHermA)
998+
return ishermitian(sqrtHermA) ? copytri_maybe_inplace(PS, 'U', true) : PS
981999
elseif istriu(A)
9821000
return triu!(parent(sqrt(UpperTriangular(A))))
9831001
elseif isreal(A)
@@ -1100,7 +1118,8 @@ function cos(A::AbstractMatrix{<:Real})
11001118
if isdiag(A)
11011119
return applydiagonal(cos, A)
11021120
elseif issymmetric(A)
1103-
return copytri!(parent(cos(Symmetric(A))), 'U')
1121+
P = parent(cos(Symmetric(A)))
1122+
return copytri_maybe_inplace(P, 'U')
11041123
end
11051124
M = im .* float.(A)
11061125
return real(exp_maybe_inplace(M))
@@ -1109,7 +1128,8 @@ function cos(A::AbstractMatrix{<:Complex})
11091128
if isdiag(A)
11101129
return applydiagonal(cos, A)
11111130
elseif ishermitian(A)
1112-
return copytri!(parent(cos(Hermitian(A))), 'U', true)
1131+
P = parent(cos(Hermitian(A)))
1132+
return copytri_maybe_inplace(P, 'U', true)
11131133
end
11141134
M = im .* float.(A)
11151135
N = -M
@@ -1140,7 +1160,8 @@ function sin(A::AbstractMatrix{<:Real})
11401160
if isdiag(A)
11411161
return applydiagonal(sin, A)
11421162
elseif issymmetric(A)
1143-
return copytri!(parent(sin(Symmetric(A))), 'U')
1163+
P = parent(sin(Symmetric(A)))
1164+
return copytri_maybe_inplace(P, 'U')
11441165
end
11451166
M = im .* float.(A)
11461167
return imag(exp_maybe_inplace(M))
@@ -1149,7 +1170,8 @@ function sin(A::AbstractMatrix{<:Complex})
11491170
if isdiag(A)
11501171
return applydiagonal(sin, A)
11511172
elseif ishermitian(A)
1152-
return copytri!(parent(sin(Hermitian(A))), 'U', true)
1173+
P = parent(sin(Hermitian(A)))
1174+
return copytri_maybe_inplace(P, 'U', true)
11531175
end
11541176
M = im .* float.(A)
11551177
Mneg = -M
@@ -1183,8 +1205,10 @@ julia> C
11831205
function sincos(A::AbstractMatrix{<:Real})
11841206
if issymmetric(A)
11851207
symsinA, symcosA = sincos(Symmetric(A))
1186-
sinA = copytri!(parent(symsinA), 'U')
1187-
cosA = copytri!(parent(symcosA), 'U')
1208+
Psin = parent(symsinA)
1209+
Pcos = parent(symcosA)
1210+
sinA = copytri_maybe_inplace(Psin, 'U')
1211+
cosA = copytri_maybe_inplace(Pcos, 'U')
11881212
return sinA, cosA
11891213
end
11901214
M = im .* float.(A)
@@ -1194,8 +1218,10 @@ end
11941218
function sincos(A::AbstractMatrix{<:Complex})
11951219
if ishermitian(A)
11961220
hermsinA, hermcosA = sincos(Hermitian(A))
1197-
sinA = copytri!(parent(hermsinA), 'U', true)
1198-
cosA = copytri!(parent(hermcosA), 'U', true)
1221+
Psin = parent(hermsinA)
1222+
Pcos = parent(hermcosA)
1223+
sinA = copytri_maybe_inplace(Psin, 'U', true)
1224+
cosA = copytri_maybe_inplace(Pcos, 'U', true)
11991225
return sinA, cosA
12001226
end
12011227
M = im .* float.(A)
@@ -1239,7 +1265,8 @@ function tan(A::AbstractMatrix)
12391265
if isdiag(A)
12401266
return applydiagonal(tan, A)
12411267
elseif ishermitian(A)
1242-
return copytri!(parent(tan(Hermitian(A))), 'U', true)
1268+
P = parent(tan(Hermitian(A)))
1269+
return copytri_maybe_inplace(P, 'U', true)
12431270
end
12441271
S, C = sincos(A)
12451272
S /= C
@@ -1255,7 +1282,8 @@ function cosh(A::AbstractMatrix)
12551282
if isdiag(A)
12561283
return applydiagonal(cosh, A)
12571284
elseif ishermitian(A)
1258-
return copytri!(parent(cosh(Hermitian(A))), 'U', true)
1285+
P = parent(cosh(Hermitian(A)))
1286+
return copytri_maybe_inplace(P, 'U', true)
12591287
end
12601288
X = exp(A)
12611289
negA = @. float(-A)
@@ -1272,7 +1300,8 @@ function sinh(A::AbstractMatrix)
12721300
if isdiag(A)
12731301
return applydiagonal(sinh, A)
12741302
elseif ishermitian(A)
1275-
return copytri!(parent(sinh(Hermitian(A))), 'U', true)
1303+
P = parent(sinh(Hermitian(A)))
1304+
return copytri_maybe_inplace(P, 'U', true)
12761305
end
12771306
X = exp(A)
12781307
negA = @. float(-A)
@@ -1289,7 +1318,8 @@ function tanh(A::AbstractMatrix)
12891318
if isdiag(A)
12901319
return applydiagonal(tanh, A)
12911320
elseif ishermitian(A)
1292-
return copytri!(parent(tanh(Hermitian(A))), 'U', true)
1321+
P = parent(tanh(Hermitian(A)))
1322+
return copytri_maybe_inplace(P, 'U', true)
12931323
end
12941324
X = exp(A)
12951325
negA = @. float(-A)
@@ -1332,7 +1362,8 @@ function acos(A::AbstractMatrix)
13321362
return applydiagonal(acos, A)
13331363
elseif ishermitian(A)
13341364
acosHermA = acos(Hermitian(A))
1335-
return isa(acosHermA, Hermitian) ? copytri!(parent(acosHermA), 'U', true) : parent(acosHermA)
1365+
P = parent(acosHermA)
1366+
return isa(acosHermA, Hermitian) ? copytri_maybe_inplace(P, 'U', true) : P
13361367
end
13371368
SchurF = Schur{Complex}(schur(A))
13381369
U = UpperTriangular(SchurF.T)
@@ -1365,7 +1396,8 @@ function asin(A::AbstractMatrix)
13651396
return applydiagonal(asin, A)
13661397
elseif ishermitian(A)
13671398
asinHermA = asin(Hermitian(A))
1368-
return isa(asinHermA, Hermitian) ? copytri!(parent(asinHermA), 'U', true) : parent(asinHermA)
1399+
P = parent(asinHermA)
1400+
return isa(asinHermA, Hermitian) ? copytri_maybe_inplace(P, 'U', true) : P
13691401
end
13701402
SchurF = Schur{Complex}(schur(A))
13711403
U = UpperTriangular(SchurF.T)
@@ -1397,7 +1429,8 @@ function atan(A::AbstractMatrix)
13971429
if isdiag(A)
13981430
return applydiagonal(atan, A)
13991431
elseif ishermitian(A)
1400-
return copytri!(parent(atan(Hermitian(A))), 'U', true)
1432+
P = parent(atan(Hermitian(A)))
1433+
return copytri_maybe_inplace(P, 'U', true)
14011434
end
14021435
SchurF = Schur{Complex}(schur(A))
14031436
U = im * UpperTriangular(SchurF.T)
@@ -1418,7 +1451,8 @@ function acosh(A::AbstractMatrix)
14181451
return applydiagonal(acosh, A)
14191452
elseif ishermitian(A)
14201453
acoshHermA = acosh(Hermitian(A))
1421-
return isa(acoshHermA, Hermitian) ? copytri!(parent(acoshHermA), 'U', true) : parent(acoshHermA)
1454+
P = parent(acoshHermA)
1455+
return isa(acoshHermA, Hermitian) ? copytri_maybe_inplace(P, 'U', true) : P
14221456
end
14231457
SchurF = Schur{Complex}(schur(A))
14241458
U = UpperTriangular(SchurF.T)
@@ -1438,7 +1472,8 @@ function asinh(A::AbstractMatrix)
14381472
if isdiag(A)
14391473
return applydiagonal(asinh, A)
14401474
elseif ishermitian(A)
1441-
return copytri!(parent(asinh(Hermitian(A))), 'U', true)
1475+
P = parent(asinh(Hermitian(A)))
1476+
return copytri_maybe_inplace(P, 'U', true)
14421477
end
14431478
SchurF = Schur{Complex}(schur(A))
14441479
U = UpperTriangular(SchurF.T)
@@ -1458,7 +1493,8 @@ function atanh(A::AbstractMatrix)
14581493
if isdiag(A)
14591494
return applydiagonal(atanh, A)
14601495
elseif ishermitian(A)
1461-
return copytri!(parent(atanh(Hermitian(A))), 'U', true)
1496+
P = parent(atanh(Hermitian(A)))
1497+
return copytri_maybe_inplace(P, 'U', true)
14621498
end
14631499
SchurF = Schur{Complex}(schur(A))
14641500
U = UpperTriangular(SchurF.T)

test/dense.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,4 +1377,25 @@ end
13771377
@test_throws "matrix is singular; factorization failed" inv(A)
13781378
end
13791379

1380+
@testset "copytri_maybe_inplace" begin
1381+
R = reshape(1:16,4,4)
1382+
for conjugate in [true, false], diag in [true, false]
1383+
@test LinearAlgebra.copytri_maybe_inplace(R, 'U', conjugate, diag) == Symmetric(R, :U)
1384+
@test LinearAlgebra.copytri_maybe_inplace(R, 'L', conjugate, diag) == Symmetric(R, :L)
1385+
end
1386+
1387+
Rc = reshape(StepRangeLen(1+2im, 3 + 4im, 16), 4, 4)
1388+
Ac = Array(Rc)
1389+
@testset for conjugate in [true, false], diag in [true, false]
1390+
tR = LinearAlgebra.copytri_maybe_inplace(Rc, 'U', conjugate, diag)
1391+
tA = LinearAlgebra.copytri_maybe_inplace(copy(Ac), 'U', conjugate, diag)
1392+
@test tR == tA
1393+
tR = LinearAlgebra.copytri_maybe_inplace(Rc, 'L', conjugate, diag)
1394+
tA = LinearAlgebra.copytri_maybe_inplace(copy(Ac), 'L', conjugate, diag)
1395+
@test tR == tA
1396+
end
1397+
1398+
@test_throws ArgumentError LinearAlgebra.copytri_maybe_inplace(Rc, 'X')
1399+
end
1400+
13801401
end # module TestDense

0 commit comments

Comments
 (0)