Skip to content

Commit 4ab4722

Browse files
committed
Reland "Reroute Symmetric/Hermitian + Diagonal through triangular"
This backports the following commits: commit 9690961c426ce2640d7db6c89952e69f87873a93 Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Mon Apr 29 21:43:31 2024 +0530 Add upper/lowertriangular functions and use in applytri (#53573) We may use the fact that a `Diagonal` is already triangular to avoid adding a wrapper. Fixes the specific example in https://github.com/JuliaLang/julia/issues/53564, although not the broader issue. This is because it changes the operation from a `UpperTriangular + UpperTriangular` to a `UpperTriangular + Diagonal`, which uses broadcasting. The latter operation may also allow one to define more efficient methods. commit 77821cdddb968eeabf31ccb6b214ccf59a604c68 Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Wed Aug 28 00:53:31 2024 +0530 Remove Diagonal-triangular specialization commit 621fb2e739a04207df63857700aca3562b41b5eb Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Wed Aug 28 00:50:49 2024 +0530 Restrict broadcasting to strided-diag Diagonal commit 58eb2045ddb5dbbfdb759c06239ca54751e73d71 Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Wed Aug 28 00:44:47 2024 +0530 Add tests for partly filled parent commit 5aa6080a580bfbc9453e94a06f3e379e4517b316 Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Tue Aug 27 20:42:07 2024 +0530 Reroute Symmetric/Hermitian + Diagonal through triangular
1 parent f09de94 commit 4ab4722

File tree

8 files changed

+86
-21
lines changed

8 files changed

+86
-21
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -250,21 +250,6 @@ end
250250
(+)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag + Db.diag)
251251
(-)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag - Db.diag)
252252

253-
for f in (:+, :-)
254-
@eval function $f(D::Diagonal{<:Number}, S::Symmetric)
255-
return Symmetric($f(D, S.data), sym_uplo(S.uplo))
256-
end
257-
@eval function $f(S::Symmetric, D::Diagonal{<:Number})
258-
return Symmetric($f(S.data, D), sym_uplo(S.uplo))
259-
end
260-
@eval function $f(D::Diagonal{<:Real}, H::Hermitian)
261-
return Hermitian($f(D, H.data), sym_uplo(H.uplo))
262-
end
263-
@eval function $f(H::Hermitian, D::Diagonal{<:Real})
264-
return Hermitian($f(H.data, D), sym_uplo(H.uplo))
265-
end
266-
end
267-
268253
(*)(x::Number, D::Diagonal) = Diagonal(x * D.diag)
269254
(*)(D::Diagonal, x::Number) = Diagonal(D.diag * x)
270255
(/)(D::Diagonal, x::Number) = Diagonal(D.diag / x)
@@ -991,3 +976,6 @@ end
991976
function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal)
992977
Diagonal(A.diag .* B.diag .+ z.diag)
993978
end
979+
980+
uppertriangular(D::Diagonal) = D
981+
lowertriangular(D::Diagonal) = D

stdlib/LinearAlgebra/src/special.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,25 @@ function (-)(A::UniformScaling, B::Diagonal)
264264
Diagonal(Ref(A) .- B.diag)
265265
end
266266

267+
for f in (:+, :-)
268+
@eval function $f(D::Diagonal{<:Number}, S::Symmetric)
269+
uplo = sym_uplo(S.uplo)
270+
return Symmetric(parentof_applytri($f, Symmetric(D, uplo), S), uplo)
271+
end
272+
@eval function $f(S::Symmetric, D::Diagonal{<:Number})
273+
uplo = sym_uplo(S.uplo)
274+
return Symmetric(parentof_applytri($f, S, Symmetric(D, uplo)), uplo)
275+
end
276+
@eval function $f(D::Diagonal{<:Real}, H::Hermitian)
277+
uplo = sym_uplo(H.uplo)
278+
return Hermitian(parentof_applytri($f, Hermitian(D, uplo), H), uplo)
279+
end
280+
@eval function $f(H::Hermitian, D::Diagonal{<:Real})
281+
uplo = sym_uplo(H.uplo)
282+
return Hermitian(parentof_applytri($f, H, Hermitian(D, uplo)), uplo)
283+
end
284+
end
285+
267286
## Diagonal construction from UniformScaling
268287
Diagonal{T}(s::UniformScaling, m::Integer) where {T} = Diagonal{T}(fill(T(s.λ), m))
269288
Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m)

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,21 +277,21 @@ diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo))
277277

278278
function applytri(f, A::HermOrSym)
279279
if A.uplo == 'U'
280-
f(UpperTriangular(A.data))
280+
f(uppertriangular(A.data))
281281
else
282-
f(LowerTriangular(A.data))
282+
f(lowertriangular(A.data))
283283
end
284284
end
285285

286286
function applytri(f, A::HermOrSym, B::HermOrSym)
287287
if A.uplo == B.uplo == 'U'
288-
f(UpperTriangular(A.data), UpperTriangular(B.data))
288+
f(uppertriangular(A.data), uppertriangular(B.data))
289289
elseif A.uplo == B.uplo == 'L'
290-
f(LowerTriangular(A.data), LowerTriangular(B.data))
290+
f(lowertriangular(A.data), lowertriangular(B.data))
291291
elseif A.uplo == 'U'
292-
f(UpperTriangular(A.data), UpperTriangular(_conjugation(B)(B.data)))
292+
f(uppertriangular(A.data), uppertriangular(_conjugation(B)(B.data)))
293293
else # A.uplo == 'L'
294-
f(UpperTriangular(_conjugation(A)(A.data)), UpperTriangular(B.data))
294+
f(uppertriangular(_conjugation(A)(A.data)), uppertriangular(B.data))
295295
end
296296
end
297297
parentof_applytri(f, args...) = applytri(parent f, args...)

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ const UpperOrUnitUpperTriangular{T,S} = Union{UpperTriangular{T,S}, UnitUpperTri
153153
const LowerOrUnitLowerTriangular{T,S} = Union{LowerTriangular{T,S}, UnitLowerTriangular{T,S}}
154154
const UpperOrLowerTriangular{T,S} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}}
155155

156+
uppertriangular(M) = UpperTriangular(M)
157+
lowertriangular(M) = LowerTriangular(M)
158+
159+
uppertriangular(U::UpperOrUnitUpperTriangular) = U
160+
lowertriangular(U::LowerOrUnitLowerTriangular) = U
161+
156162
imag(A::UpperTriangular) = UpperTriangular(imag(A.data))
157163
imag(A::LowerTriangular) = LowerTriangular(imag(A.data))
158164
imag(A::UpperTriangular{<:Any,<:StridedMaybeAdjOrTransMat}) = imag.(A)

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,12 @@ end
12771277
@test c == Diagonal([2,2,2,2])
12781278
end
12791279

1280+
@testset "uppertriangular/lowertriangular" begin
1281+
D = Diagonal([1,2])
1282+
@test LinearAlgebra.uppertriangular(D) === D
1283+
@test LinearAlgebra.lowertriangular(D) === D
1284+
end
1285+
12801286
@testset "mul/div with an adjoint vector" begin
12811287
A = [1.0;;]
12821288
x = [1.0]

stdlib/LinearAlgebra/test/special.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,4 +536,17 @@ end
536536
@test v * S isa Matrix
537537
end
538538

539+
@testset "Partly filled Hermitian and Diagonal algebra" begin
540+
D = Diagonal([1,2])
541+
for S in (Symmetric, Hermitian), uplo in (:U, :L)
542+
M = Matrix{BigInt}(undef, 2, 2)
543+
M[1,1] = M[2,2] = M[1+(uplo == :L), 1 + (uplo == :U)] = 3
544+
H = S(M, uplo)
545+
HM = Matrix(H)
546+
@test H + D == D + H == HM + D
547+
@test H - D == HM - D
548+
@test D - H == D - HM
549+
end
550+
end
551+
539552
end # module TestSpecial

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,31 @@ end
507507
@test Su - Sl == -(Sl - Su) == MSu - MSl
508508
end
509509
end
510+
@testset "non-strided" begin
511+
@testset "diagonal" begin
512+
for ST1 in (Symmetric, Hermitian), uplo1 in (:L, :U)
513+
m = ST1(Matrix{BigFloat}(undef,2,2), uplo1)
514+
m.data[1,1] = 1
515+
m.data[2,2] = 3
516+
m.data[1+(uplo1==:L), 1+(uplo1==:U)] = 2
517+
A = Array(m)
518+
for ST2 in (Symmetric, Hermitian), uplo2 in (:L, :U)
519+
id = ST2(I(2), uplo2)
520+
@test m + id == id + m == A + id
521+
end
522+
end
523+
end
524+
@testset "unit triangular" begin
525+
for ST1 in (Symmetric, Hermitian), uplo1 in (:L, :U)
526+
H1 = ST1(UnitUpperTriangular(big.(rand(Int8,4,4))), uplo1)
527+
M1 = Matrix(H1)
528+
for ST2 in (Symmetric, Hermitian), uplo2 in (:L, :U)
529+
H2 = ST2(UnitUpperTriangular(big.(rand(Int8,4,4))), uplo2)
530+
@test H1 + H2 == M1 + Matrix(H2)
531+
end
532+
end
533+
end
534+
end
510535
end
511536

512537
# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,

stdlib/LinearAlgebra/test/triangular.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,14 @@ end
996996
end
997997
end
998998

999+
@testset "uppertriangular/lowertriangular" begin
1000+
M = rand(2,2)
1001+
@test LinearAlgebra.uppertriangular(M) === UpperTriangular(M)
1002+
@test LinearAlgebra.lowertriangular(M) === LowerTriangular(M)
1003+
@test LinearAlgebra.uppertriangular(UnitUpperTriangular(M)) === UnitUpperTriangular(M)
1004+
@test LinearAlgebra.lowertriangular(UnitLowerTriangular(M)) === UnitLowerTriangular(M)
1005+
end
1006+
9991007
@testset "arithmetic with partly uninitialized matrices" begin
10001008
@testset "$(typeof(A))" for A in (Matrix{BigFloat}(undef,2,2), Matrix{Complex{BigFloat}}(undef,2,2)')
10011009
A[2,1] = eltype(A) <: Complex ? 4 + 3im : 4

0 commit comments

Comments
 (0)