Skip to content

Commit d40fa57

Browse files
committed
Reland "Reroute Symmetric/Hermitian + Diagonal through triangular"
This backports the following commits: commit 9690961c426ce2640d7db6c89952e69f87873a93 Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Mon Apr 29 21:43:31 2024 +0530 Add upper/lowertriangular functions and use in applytri (#53573) We may use the fact that a `Diagonal` is already triangular to avoid adding a wrapper. Fixes the specific example in https://github.com/JuliaLang/julia/issues/53564, although not the broader issue. This is because it changes the operation from a `UpperTriangular + UpperTriangular` to a `UpperTriangular + Diagonal`, which uses broadcasting. The latter operation may also allow one to define more efficient methods. commit 77821cdddb968eeabf31ccb6b214ccf59a604c68 Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Wed Aug 28 00:53:31 2024 +0530 Remove Diagonal-triangular specialization commit 621fb2e739a04207df63857700aca3562b41b5eb Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Wed Aug 28 00:50:49 2024 +0530 Restrict broadcasting to strided-diag Diagonal commit 58eb2045ddb5dbbfdb759c06239ca54751e73d71 Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Wed Aug 28 00:44:47 2024 +0530 Add tests for partly filled parent commit 5aa6080a580bfbc9453e94a06f3e379e4517b316 Author: Jishnu Bhattacharya <jishnub.github@gmail.com> Date: Tue Aug 27 20:42:07 2024 +0530 Reroute Symmetric/Hermitian + Diagonal through triangular
1 parent f09de94 commit d40fa57

File tree

8 files changed

+326
-21
lines changed

8 files changed

+326
-21
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -250,21 +250,6 @@ end
250250
(+)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag + Db.diag)
251251
(-)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag - Db.diag)
252252

253-
for f in (:+, :-)
254-
@eval function $f(D::Diagonal{<:Number}, S::Symmetric)
255-
return Symmetric($f(D, S.data), sym_uplo(S.uplo))
256-
end
257-
@eval function $f(S::Symmetric, D::Diagonal{<:Number})
258-
return Symmetric($f(S.data, D), sym_uplo(S.uplo))
259-
end
260-
@eval function $f(D::Diagonal{<:Real}, H::Hermitian)
261-
return Hermitian($f(D, H.data), sym_uplo(H.uplo))
262-
end
263-
@eval function $f(H::Hermitian, D::Diagonal{<:Real})
264-
return Hermitian($f(H.data, D), sym_uplo(H.uplo))
265-
end
266-
end
267-
268253
(*)(x::Number, D::Diagonal) = Diagonal(x * D.diag)
269254
(*)(D::Diagonal, x::Number) = Diagonal(D.diag * x)
270255
(/)(D::Diagonal, x::Number) = Diagonal(D.diag / x)
@@ -991,3 +976,6 @@ end
991976
function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal)
992977
Diagonal(A.diag .* B.diag .+ z.diag)
993978
end
979+
980+
uppertriangular(D::Diagonal) = D
981+
lowertriangular(D::Diagonal) = D

stdlib/LinearAlgebra/src/special.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,25 @@ function (-)(A::UniformScaling, B::Diagonal)
264264
Diagonal(Ref(A) .- B.diag)
265265
end
266266

267+
for f in (:+, :-)
268+
@eval function $f(D::Diagonal{<:Number}, S::Symmetric)
269+
uplo = sym_uplo(S.uplo)
270+
return Symmetric(parentof_applytri($f, Symmetric(D, uplo), S), uplo)
271+
end
272+
@eval function $f(S::Symmetric, D::Diagonal{<:Number})
273+
uplo = sym_uplo(S.uplo)
274+
return Symmetric(parentof_applytri($f, S, Symmetric(D, uplo)), uplo)
275+
end
276+
@eval function $f(D::Diagonal{<:Real}, H::Hermitian)
277+
uplo = sym_uplo(H.uplo)
278+
return Hermitian(parentof_applytri($f, Hermitian(D, uplo), H), uplo)
279+
end
280+
@eval function $f(H::Hermitian, D::Diagonal{<:Real})
281+
uplo = sym_uplo(H.uplo)
282+
return Hermitian(parentof_applytri($f, H, Hermitian(D, uplo)), uplo)
283+
end
284+
end
285+
267286
## Diagonal construction from UniformScaling
268287
Diagonal{T}(s::UniformScaling, m::Integer) where {T} = Diagonal{T}(fill(T(s.λ), m))
269288
Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m)

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,21 +277,21 @@ diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo))
277277

278278
function applytri(f, A::HermOrSym)
279279
if A.uplo == 'U'
280-
f(UpperTriangular(A.data))
280+
f(uppertriangular(A.data))
281281
else
282-
f(LowerTriangular(A.data))
282+
f(lowertriangular(A.data))
283283
end
284284
end
285285

286286
function applytri(f, A::HermOrSym, B::HermOrSym)
287287
if A.uplo == B.uplo == 'U'
288-
f(UpperTriangular(A.data), UpperTriangular(B.data))
288+
f(uppertriangular(A.data), uppertriangular(B.data))
289289
elseif A.uplo == B.uplo == 'L'
290-
f(LowerTriangular(A.data), LowerTriangular(B.data))
290+
f(lowertriangular(A.data), lowertriangular(B.data))
291291
elseif A.uplo == 'U'
292-
f(UpperTriangular(A.data), UpperTriangular(_conjugation(B)(B.data)))
292+
f(uppertriangular(A.data), uppertriangular(_conjugation(B)(B.data)))
293293
else # A.uplo == 'L'
294-
f(UpperTriangular(_conjugation(A)(A.data)), UpperTriangular(B.data))
294+
f(uppertriangular(_conjugation(A)(A.data)), uppertriangular(B.data))
295295
end
296296
end
297297
parentof_applytri(f, args...) = applytri(parent f, args...)

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,14 @@ const UpperOrUnitUpperTriangular{T,S} = Union{UpperTriangular{T,S}, UnitUpperTri
153153
const LowerOrUnitLowerTriangular{T,S} = Union{LowerTriangular{T,S}, UnitLowerTriangular{T,S}}
154154
const UpperOrLowerTriangular{T,S} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}}
155155

156+
uppertriangular(M) = UpperTriangular(M)
157+
lowertriangular(M) = LowerTriangular(M)
158+
159+
uppertriangular(U::UpperOrUnitUpperTriangular) = U
160+
lowertriangular(U::LowerOrUnitLowerTriangular) = U
161+
162+
Base.dataids(A::UpperOrLowerTriangular) = Base.dataids(A.data)
163+
156164
imag(A::UpperTriangular) = UpperTriangular(imag(A.data))
157165
imag(A::LowerTriangular) = LowerTriangular(imag(A.data))
158166
imag(A::UpperTriangular{<:Any,<:StridedMaybeAdjOrTransMat}) = imag.(A)

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,12 @@ end
12771277
@test c == Diagonal([2,2,2,2])
12781278
end
12791279

1280+
@testset "uppertriangular/lowertriangular" begin
1281+
D = Diagonal([1,2])
1282+
@test LinearAlgebra.uppertriangular(D) === D
1283+
@test LinearAlgebra.lowertriangular(D) === D
1284+
end
1285+
12801286
@testset "mul/div with an adjoint vector" begin
12811287
A = [1.0;;]
12821288
x = [1.0]

stdlib/LinearAlgebra/test/special.jl

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,4 +536,255 @@ end
536536
@test v * S isa Matrix
537537
end
538538

539+
@testset "copyto! between matrix types" begin
540+
dl, d, du = zeros(Int,4), [1:5;], zeros(Int,4)
541+
d_ones = ones(Int,size(du))
542+
543+
@testset "from Diagonal" begin
544+
D = Diagonal(d)
545+
@testset "to Bidiagonal" begin
546+
BU = Bidiagonal(similar(d, BigInt), similar(du, BigInt), :U)
547+
BL = Bidiagonal(similar(d, BigInt), similar(dl, BigInt), :L)
548+
for B in (BL, BU)
549+
copyto!(B, D)
550+
@test B == D
551+
end
552+
553+
@testset "mismatched size" begin
554+
for B in (BU, BL)
555+
B .= 0
556+
copyto!(B, Diagonal(Int[1]))
557+
@test B[1,1] == 1
558+
B[1,1] = 0
559+
@test iszero(B)
560+
end
561+
end
562+
end
563+
@testset "to Tridiagonal" begin
564+
T = Tridiagonal(similar(dl, BigInt), similar(d, BigInt), similar(du, BigInt))
565+
copyto!(T, D)
566+
@test T == D
567+
568+
@testset "mismatched size" begin
569+
T .= 0
570+
copyto!(T, Diagonal([1]))
571+
@test T[1,1] == 1
572+
T[1,1] = 0
573+
@test iszero(T)
574+
end
575+
end
576+
@testset "to SymTridiagonal" begin
577+
for du2 in (similar(du, BigInt), similar(d, BigInt))
578+
S = SymTridiagonal(similar(d), du2)
579+
copyto!(S, D)
580+
@test S == D
581+
end
582+
583+
@testset "mismatched size" begin
584+
S = SymTridiagonal(zero(d), zero(du))
585+
copyto!(S, Diagonal([1]))
586+
@test S[1,1] == 1
587+
S[1,1] = 0
588+
@test iszero(S)
589+
end
590+
end
591+
end
592+
593+
@testset "from Bidiagonal" begin
594+
BU = Bidiagonal(d, du, :U)
595+
BUones = Bidiagonal(d, oneunit.(du), :U)
596+
BL = Bidiagonal(d, dl, :L)
597+
BLones = Bidiagonal(d, oneunit.(dl), :L)
598+
@testset "to Diagonal" begin
599+
D = Diagonal(zero(d))
600+
for B in (BL, BU)
601+
@test copyto!(D, B) == B
602+
D .= 0
603+
end
604+
for B in (BLones, BUones)
605+
errmsg = "cannot copy a Bidiagonal with a non-zero off-diagonal band to a Diagonal"
606+
@test_throws errmsg copyto!(D, B)
607+
@test iszero(D)
608+
end
609+
610+
@testset "mismatched size" begin
611+
for uplo in (:L, :U)
612+
D .= 0
613+
copyto!(D, Bidiagonal(Int[1], Int[], uplo))
614+
@test D[1,1] == 1
615+
D[1,1] = 0
616+
@test iszero(D)
617+
end
618+
end
619+
end
620+
@testset "to Tridiagonal" begin
621+
T = Tridiagonal(similar(dl, BigInt), similar(d, BigInt), similar(du, BigInt))
622+
for B in (BL, BU, BLones, BUones)
623+
copyto!(T, B)
624+
@test T == B
625+
end
626+
627+
@testset "mismatched size" begin
628+
T = Tridiagonal(oneunit.(dl), zero(d), oneunit.(du))
629+
for uplo in (:L, :U)
630+
T .= 0
631+
copyto!(T, Bidiagonal([1], Int[], uplo))
632+
@test T[1,1] == 1
633+
T[1,1] = 0
634+
@test iszero(T)
635+
end
636+
end
637+
end
638+
@testset "to SymTridiagonal" begin
639+
for du2 in (similar(du, BigInt), similar(d, BigInt))
640+
S = SymTridiagonal(similar(d, BigInt), du2)
641+
for B in (BL, BU)
642+
copyto!(S, B)
643+
@test S == B
644+
end
645+
errmsg = "cannot copy a non-symmetric Bidiagonal matrix to a SymTridiagonal"
646+
@test_throws errmsg copyto!(S, BUones)
647+
@test_throws errmsg copyto!(S, BLones)
648+
end
649+
650+
@testset "mismatched size" begin
651+
S = SymTridiagonal(zero(d), zero(du))
652+
for uplo in (:L, :U)
653+
copyto!(S, Bidiagonal([1], Int[], uplo))
654+
@test S[1,1] == 1
655+
S[1,1] = 0
656+
@test iszero(S)
657+
end
658+
end
659+
end
660+
end
661+
662+
@testset "from Tridiagonal" begin
663+
T = Tridiagonal(dl, d, du)
664+
TU = Tridiagonal(dl, d, d_ones)
665+
TL = Tridiagonal(d_ones, d, dl)
666+
@testset "to Diagonal" begin
667+
D = Diagonal(zero(d))
668+
@test copyto!(D, T) == Diagonal(d)
669+
errmsg = "cannot copy a Tridiagonal with a non-zero off-diagonal band to a Diagonal"
670+
D .= 0
671+
@test_throws errmsg copyto!(D, TU)
672+
@test iszero(D)
673+
errmsg = "cannot copy a Tridiagonal with a non-zero off-diagonal band to a Diagonal"
674+
@test_throws errmsg copyto!(D, TL)
675+
@test iszero(D)
676+
677+
@testset "mismatched size" begin
678+
D .= 0
679+
copyto!(D, Tridiagonal(Int[], Int[1], Int[]))
680+
@test D[1,1] == 1
681+
D[1,1] = 0
682+
@test iszero(D)
683+
end
684+
end
685+
@testset "to Bidiagonal" begin
686+
BU = Bidiagonal(zero(d), zero(du), :U)
687+
BL = Bidiagonal(zero(d), zero(du), :L)
688+
@test copyto!(BU, T) == Bidiagonal(d, du, :U)
689+
@test copyto!(BL, T) == Bidiagonal(d, du, :L)
690+
691+
BU .= 0
692+
BL .= 0
693+
errmsg = "cannot copy a Tridiagonal with a non-zero superdiagonal to a Bidiagonal with uplo=:L"
694+
@test_throws errmsg copyto!(BL, TU)
695+
@test iszero(BL)
696+
@test copyto!(BU, TU) == Bidiagonal(d, d_ones, :U)
697+
698+
BU .= 0
699+
BL .= 0
700+
@test copyto!(BL, TL) == Bidiagonal(d, d_ones, :L)
701+
errmsg = "cannot copy a Tridiagonal with a non-zero subdiagonal to a Bidiagonal with uplo=:U"
702+
@test_throws errmsg copyto!(BU, TL)
703+
@test iszero(BU)
704+
705+
@testset "mismatched size" begin
706+
for B in (BU, BL)
707+
B .= 0
708+
copyto!(B, Tridiagonal(Int[], Int[1], Int[]))
709+
@test B[1,1] == 1
710+
B[1,1] = 0
711+
@test iszero(B)
712+
end
713+
end
714+
end
715+
end
716+
717+
@testset "from SymTridiagonal" begin
718+
S2 = SymTridiagonal(d, ones(Int,size(d)))
719+
for S in (SymTridiagonal(d, du), SymTridiagonal(d, zero(d)))
720+
@testset "to Diagonal" begin
721+
D = Diagonal(zero(d))
722+
@test copyto!(D, S) == Diagonal(d)
723+
D .= 0
724+
errmsg = "cannot copy a SymTridiagonal with a non-zero off-diagonal band to a Diagonal"
725+
@test_throws errmsg copyto!(D, S2)
726+
@test iszero(D)
727+
728+
@testset "mismatched size" begin
729+
D .= 0
730+
copyto!(D, SymTridiagonal(Int[1], Int[]))
731+
@test D[1,1] == 1
732+
D[1,1] = 0
733+
@test iszero(D)
734+
end
735+
end
736+
@testset "to Bidiagonal" begin
737+
BU = Bidiagonal(zero(d), zero(du), :U)
738+
BL = Bidiagonal(zero(d), zero(du), :L)
739+
@test copyto!(BU, S) == Bidiagonal(d, du, :U)
740+
@test copyto!(BL, S) == Bidiagonal(d, du, :L)
741+
742+
BU .= 0
743+
BL .= 0
744+
errmsg = "cannot copy a SymTridiagonal with a non-zero off-diagonal band to a Bidiagonal"
745+
@test_throws errmsg copyto!(BU, S2)
746+
@test iszero(BU)
747+
@test_throws errmsg copyto!(BL, S2)
748+
@test iszero(BL)
749+
750+
@testset "mismatched size" begin
751+
for B in (BU, BL)
752+
B .= 0
753+
copyto!(B, SymTridiagonal(Int[1], Int[]))
754+
@test B[1,1] == 1
755+
B[1,1] = 0
756+
@test iszero(B)
757+
end
758+
end
759+
end
760+
end
761+
end
762+
end
763+
764+
@testset "BandIndex indexing" begin
765+
for D in (Diagonal(1:3), Bidiagonal(1:3, 2:3, :U), Bidiagonal(1:3, 2:3, :L),
766+
Tridiagonal(2:3, 1:3, 1:2), SymTridiagonal(1:3, 2:3))
767+
M = Matrix(D)
768+
for band in -size(D,1)+1:size(D,1)-1
769+
for idx in 1:size(D,1)-abs(band)
770+
@test D[BandIndex(band, idx)] == M[BandIndex(band, idx)]
771+
end
772+
end
773+
@test_throws BoundsError D[BandIndex(size(D,1),1)]
774+
end
775+
end
776+
777+
@testset "Partly filled Hermitian and Diagonal algebra" begin
778+
D = Diagonal([1,2])
779+
for S in (Symmetric, Hermitian), uplo in (:U, :L)
780+
M = Matrix{BigInt}(undef, 2, 2)
781+
M[1,1] = M[2,2] = M[1+(uplo == :L), 1 + (uplo == :U)] = 3
782+
H = S(M, uplo)
783+
HM = Matrix(H)
784+
@test H + D == D + H == HM + D
785+
@test H - D == HM - D
786+
@test D - H == D - HM
787+
end
788+
end
789+
539790
end # module TestSpecial

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,31 @@ end
507507
@test Su - Sl == -(Sl - Su) == MSu - MSl
508508
end
509509
end
510+
@testset "non-strided" begin
511+
@testset "diagonal" begin
512+
for ST1 in (Symmetric, Hermitian), uplo1 in (:L, :U)
513+
m = ST1(Matrix{BigFloat}(undef,2,2), uplo1)
514+
m.data[1,1] = 1
515+
m.data[2,2] = 3
516+
m.data[1+(uplo1==:L), 1+(uplo1==:U)] = 2
517+
A = Array(m)
518+
for ST2 in (Symmetric, Hermitian), uplo2 in (:L, :U)
519+
id = ST2(I(2), uplo2)
520+
@test m + id == id + m == A + id
521+
end
522+
end
523+
end
524+
@testset "unit triangular" begin
525+
for ST1 in (Symmetric, Hermitian), uplo1 in (:L, :U)
526+
H1 = ST1(UnitUpperTriangular(big.(rand(Int8,4,4))), uplo1)
527+
M1 = Matrix(H1)
528+
for ST2 in (Symmetric, Hermitian), uplo2 in (:L, :U)
529+
H2 = ST2(UnitUpperTriangular(big.(rand(Int8,4,4))), uplo2)
530+
@test H1 + H2 == M1 + Matrix(H2)
531+
end
532+
end
533+
end
534+
end
510535
end
511536

512537
# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,

0 commit comments

Comments
 (0)