Skip to content

Commit 89cae45

Browse files
authored
Optimized arithmetic methods for strided triangular matrices (#52571)
This uses broadcasting for operations like `A::UpperTriangular + B::UpperTriangular` in case the parents are `StridedMatrix`es. Looping only over the triangular part is usually faster for large matrices, where presumably memory is the bottleneck. Some performance comparisons, using ```julia julia> U = UpperTriangular(rand(1000,1000)); julia> U1 = UnitUpperTriangular(rand(size(U)...)); ``` | Operation | master | PR | | --------------- | ---------- | ----- | |`-U` |`1.011 ms (3 allocations: 7.63 MiB)` |`559.680 μs (3 allocations: 7.63 MiB)` | |`U + U`/`U - U` |`971.740 μs (3 allocations: 7.63 MiB)` | `560.063 μs (3 allocations: 7.63 MiB)` | |`U + U1`/`U - U1` |`3.014 ms (9 allocations: 22.89 MiB)` | `944.772 μs (3 allocations: 7.63 MiB)` | |`U1 + U1` |`4.509 ms (12 allocations: 30.52 MiB)` | `1.687 ms (3 allocations: 7.63 MiB)` | |`U1 - U1` |`3.357 ms (9 allocations: 22.89 MiB)` | `1.763 ms (3 allocations: 7.63 MiB)` | I've retained the existing methods as fallback, in case there's current code that works without broadcasting.
1 parent fe0db7d commit 89cae45

File tree

2 files changed

+107
-14
lines changed

2 files changed

+107
-14
lines changed

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ for t in (:LowerTriangular, :UnitLowerTriangular, :UpperTriangular, :UnitUpperTr
4848

4949
real(A::$t{<:Real}) = A
5050
real(A::$t{<:Complex}) = (B = real(A.data); $t(B))
51+
real(A::$t{<:Complex, <:StridedMaybeAdjOrTransMat}) = $t(real.(A))
5152
end
5253
end
5354

@@ -156,8 +157,26 @@ const UpperOrLowerTriangular{T,S} = Union{UpperOrUnitUpperTriangular{T,S}, Lower
156157

157158
imag(A::UpperTriangular) = UpperTriangular(imag(A.data))
158159
imag(A::LowerTriangular) = LowerTriangular(imag(A.data))
159-
imag(A::UnitLowerTriangular) = LowerTriangular(tril!(imag(A.data),-1))
160-
imag(A::UnitUpperTriangular) = UpperTriangular(triu!(imag(A.data),1))
160+
imag(A::UpperTriangular{<:Any,<:StridedMaybeAdjOrTransMat}) = imag.(A)
161+
imag(A::LowerTriangular{<:Any,<:StridedMaybeAdjOrTransMat}) = imag.(A)
162+
function imag(A::UnitLowerTriangular)
163+
L = LowerTriangular(A.data)
164+
Lim = similar(L) # must be mutable to set diagonals to zero
165+
Lim .= imag.(L)
166+
for i in 1:size(Lim,1)
167+
Lim[i,i] = zero(Lim[i,i])
168+
end
169+
return Lim
170+
end
171+
function imag(A::UnitUpperTriangular)
172+
U = UpperTriangular(A.data)
173+
Uim = similar(U) # must be mutable to set diagonals to zero
174+
Uim .= imag.(U)
175+
for i in 1:size(Uim,1)
176+
Uim[i,i] = zero(Uim[i,i])
177+
end
178+
return Uim
179+
end
161180

162181
Array(A::AbstractTriangular) = Matrix(A)
163182
parent(A::UpperOrLowerTriangular) = A.data
@@ -481,6 +500,11 @@ function -(A::UnitUpperTriangular)
481500
UpperTriangular(Anew)
482501
end
483502

503+
# use broadcasting if the parents are strided, where we loop only over the triangular part
504+
for TM in (:LowerTriangular, :UpperTriangular)
505+
@eval -(A::$TM{<:Any, <:StridedMaybeAdjOrTransMat}) = broadcast(-, A)
506+
end
507+
484508
tr(A::LowerTriangular) = tr(A.data)
485509
tr(A::UnitLowerTriangular) = size(A, 1) * oneunit(eltype(A))
486510
tr(A::UpperTriangular) = tr(A.data)
@@ -719,6 +743,16 @@ fillstored!(A::UnitUpperTriangular, x) = (fillband!(A.data, x, 1, size(A,2)-1);
719743
-(A::UnitLowerTriangular, B::UnitLowerTriangular) = LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
720744
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)
721745

746+
# use broadcasting if the parents are strided, where we loop only over the triangular part
747+
for op in (:+, :-)
748+
for TM1 in (:LowerTriangular, :UnitLowerTriangular), TM2 in (:LowerTriangular, :UnitLowerTriangular)
749+
@eval $op(A::$TM1{<:Any, <:StridedMaybeAdjOrTransMat}, B::$TM2{<:Any, <:StridedMaybeAdjOrTransMat}) = broadcast($op, A, B)
750+
end
751+
for TM1 in (:UpperTriangular, :UnitUpperTriangular), TM2 in (:UpperTriangular, :UnitUpperTriangular)
752+
@eval $op(A::$TM1{<:Any, <:StridedMaybeAdjOrTransMat}, B::$TM2{<:Any, <:StridedMaybeAdjOrTransMat}) = broadcast($op, A, B)
753+
end
754+
end
755+
722756
######################
723757
# BlasFloat routines #
724758
######################
@@ -918,47 +952,52 @@ end
918952

919953
for (t, unitt) in ((UpperTriangular, UnitUpperTriangular),
920954
(LowerTriangular, UnitLowerTriangular))
955+
tstrided = t{<:Any, <:StridedMaybeAdjOrTransMat}
921956
@eval begin
922957
(*)(A::$t, x::Number) = $t(A.data*x)
958+
(*)(A::$tstrided, x::Number) = A .* x
923959

924960
function (*)(A::$unitt, x::Number)
925-
B = A.data*x
961+
B = $t(A.data)*x
926962
for i = 1:size(A, 1)
927-
B[i,i] = x
963+
B.data[i,i] = x
928964
end
929-
$t(B)
965+
return B
930966
end
931967

932968
(*)(x::Number, A::$t) = $t(x*A.data)
969+
(*)(x::Number, A::$tstrided) = x .* A
933970

934971
function (*)(x::Number, A::$unitt)
935-
B = x*A.data
972+
B = x*$t(A.data)
936973
for i = 1:size(A, 1)
937-
B[i,i] = x
974+
B.data[i,i] = x
938975
end
939-
$t(B)
976+
return B
940977
end
941978

942979
(/)(A::$t, x::Number) = $t(A.data/x)
980+
(/)(A::$tstrided, x::Number) = A ./ x
943981

944982
function (/)(A::$unitt, x::Number)
945-
B = A.data/x
983+
B = $t(A.data)/x
946984
invx = inv(x)
947985
for i = 1:size(A, 1)
948-
B[i,i] = invx
986+
B.data[i,i] = invx
949987
end
950-
$t(B)
988+
return B
951989
end
952990

953991
(\)(x::Number, A::$t) = $t(x\A.data)
992+
(\)(x::Number, A::$tstrided) = x .\ A
954993

955994
function (\)(x::Number, A::$unitt)
956-
B = x\A.data
995+
B = x\$t(A.data)
957996
invx = inv(x)
958997
for i = 1:size(A, 1)
959-
B[i,i] = invx
998+
B.data[i,i] = invx
960999
end
961-
$t(B)
1000+
return B
9621001
end
9631002
end
9641003
end

stdlib/LinearAlgebra/test/triangular.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,23 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo
526526
end
527527
end
528528

529+
@testset "non-strided arithmetic" begin
530+
for (T,T1) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
531+
U = T(reshape(1:16, 4, 4))
532+
M = Matrix(U)
533+
@test -U == -M
534+
U1 = T1(reshape(1:16, 4, 4))
535+
M1 = Matrix(U1)
536+
@test -U1 == -M1
537+
for op in (+, -)
538+
for (A, MA) in ((U, M), (U1, M1)), (B, MB) in ((U, M), (U1, M1))
539+
@test op(A, B) == op(MA, MB)
540+
end
541+
end
542+
@test imag(U) == zero(U)
543+
end
544+
end
545+
529546
# Matrix square root
530547
Atn = UpperTriangular([-1 1 2; 0 -2 2; 0 0 -3])
531548
Atp = UpperTriangular([1 1 2; 0 2 2; 0 0 3])
@@ -894,6 +911,11 @@ end
894911
U = UT(F)
895912
@test -U == -Array(U)
896913
end
914+
915+
F = FillArrays.Fill(3im, (4,4))
916+
for U in (UnitUpperTriangular(F), UnitLowerTriangular(F))
917+
@test imag(F) == imag(collect(F))
918+
end
897919
end
898920

899921
@testset "error paths" begin
@@ -911,4 +933,36 @@ end
911933
end
912934
end
913935

936+
@testset "arithmetic with partly uninitialized matrices" begin
937+
@testset "$(typeof(A))" for A in (Matrix{BigFloat}(undef,2,2), Matrix{Complex{BigFloat}}(undef,2,2)')
938+
A[1,1] = A[2,2] = A[2,1] = 4
939+
B = Matrix{eltype(A)}(undef, size(A))
940+
for MT in (LowerTriangular, UnitLowerTriangular)
941+
L = MT(A)
942+
B .= 0
943+
copyto!(B, L)
944+
@test L * 2 == 2 * L == 2B
945+
@test L/2 == B/2
946+
@test 2\L == 2\B
947+
@test real(L) == real(B)
948+
@test imag(L) == imag(B)
949+
end
950+
end
951+
952+
@testset "$(typeof(A))" for A in (Matrix{BigFloat}(undef,2,2), Matrix{Complex{BigFloat}}(undef,2,2)')
953+
A[1,1] = A[2,2] = A[1,2] = 4
954+
B = Matrix{eltype(A)}(undef, size(A))
955+
for MT in (UpperTriangular, UnitUpperTriangular)
956+
U = MT(A)
957+
B .= 0
958+
copyto!(B, U)
959+
@test U * 2 == 2 * U == 2B
960+
@test U/2 == B/2
961+
@test 2\U == 2\B
962+
@test real(U) == real(B)
963+
@test imag(U) == imag(B)
964+
end
965+
end
966+
end
967+
914968
end # module TestTriangular

0 commit comments

Comments
 (0)