Skip to content

Commit f127c4a

Browse files
authored
Merge pull request #511 from kleinschmidt/dfk/tritri
add specialized methods for Triangular * Triangular
2 parents d98e49b + c90c055 commit f127c4a

File tree

2 files changed

+171
-19
lines changed

2 files changed

+171
-19
lines changed

src/triangular.jl

Lines changed: 148 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Transpose{<:Any,<:StaticVecOrMat}) =
1818
transpose(transpose(B) * transpose(A))
1919

20+
const StaticULT = Union{UpperTriangular{<:Any,<:StaticMatrix},LowerTriangular{<:Any,<:StaticMatrix}}
21+
2022
@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::StaticVecOrMat) = _A_mul_B(Size(A), Size(B), A, B)
2123
@inline Base.:*(A::StaticVecOrMat, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = _A_mul_B(Size(A), Size(B), A, B)
22-
@inline Base.:\(A::Union{UpperTriangular{<:Any,<:StaticMatrix},LowerTriangular{<:Any,<:StaticMatrix}}, B::StaticVecOrMat) = _A_ldiv_B(Size(A), Size(B), A, B)
24+
@inline Base.:*(A::StaticULT, B::StaticULT) = _A_mul_B(Size(A), Size(B), A, B)
25+
@inline Base.:\(A::StaticULT, B::StaticVecOrMat) = _A_ldiv_B(Size(A), Size(B), A, B)
2326

2427

2528
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB}
@@ -31,7 +34,7 @@
3134

3235
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
3336

34-
code = quote end
37+
code = Expr(:block)
3538
for j = 1:n
3639
for i = 1:m
3740
ex = :(A.data[$(LinearIndices(sa)[i, i])]*B[$(LinearIndices(sb)[i, j])])
@@ -59,7 +62,7 @@ end
5962

6063
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
6164

62-
code = quote end
65+
code = Expr(:block)
6366
for j = 1:n
6467
for i = m:-1:1
6568
ex = :(A.data[$(LinearIndices(sa)[i, i])]'*B[$(LinearIndices(sb)[i, j])])
@@ -87,7 +90,7 @@ end
8790

8891
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
8992

90-
code = quote end
93+
code = Expr(:block)
9194
for j = 1:n
9295
for i = m:-1:1
9396
ex = :(transpose(A.data[$(LinearIndices(sa)[i, i])])*B[$(LinearIndices(sb)[i, j])])
@@ -115,7 +118,7 @@ end
115118

116119
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
117120

118-
code = quote end
121+
code = Expr(:block)
119122
for j = 1:n
120123
for i = m:-1:1
121124
ex = :(A.data[$(LinearIndices(sa)[i, i])]*B[$(LinearIndices(sb)[i, j])])
@@ -143,7 +146,7 @@ end
143146

144147
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
145148

146-
code = quote end
149+
code = Expr(:block)
147150
for j = 1:n
148151
for i = 1:m
149152
ex = :(A.data[$(LinearIndices(sa)[i, i])]'*B[$(LinearIndices(sb)[i, j])])
@@ -171,7 +174,7 @@ end
171174

172175
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
173176

174-
code = quote end
177+
code = Expr(:block)
175178
for j = 1:n
176179
for i = 1:m
177180
ex = :(transpose(A.data[$(LinearIndices(sa)[i, i])])*B[$(LinearIndices(sb)[i, j])])
@@ -203,7 +206,7 @@ end
203206

204207
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
205208

206-
code = quote end
209+
code = Expr(:block)
207210
for i = 1:m
208211
for j = n:-1:1
209212
ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])])
@@ -235,7 +238,7 @@ end
235238

236239
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
237240

238-
code = quote end
241+
code = Expr(:block)
239242
for i = 1:m
240243
for j = 1:n
241244
ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]')
@@ -262,7 +265,7 @@ end
262265

263266
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
264267

265-
code = quote end
268+
code = Expr(:block)
266269
for i = 1:m
267270
for j = 1:n
268271
ex = :(A[$(LinearIndices(sa)[i, j])]*transpose(B[$(LinearIndices(sb)[j, j])]))
@@ -294,7 +297,7 @@ end
294297

295298
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
296299

297-
code = quote end
300+
code = Expr(:block)
298301
for i = 1:m
299302
for j = 1:n
300303
ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])])
@@ -326,7 +329,7 @@ end
326329

327330
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
328331

329-
code = quote end
332+
code = Expr(:block)
330333
for i = 1:m
331334
for j = n:-1:1
332335
ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]')
@@ -353,7 +356,7 @@ end
353356

354357
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
355358

356-
code = quote end
359+
code = Expr(:block)
357360
for i = 1:m
358361
for j = n:-1:1
359362
ex = :(A[$(LinearIndices(sa)[i, j])]*transpose(B[$(LinearIndices(sb)[j, j])]))
@@ -382,7 +385,7 @@ end
382385
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
383386
init = [:($(X[i,j]) = B[$(LinearIndices(sb)[i, j])]) for i = 1:m, j = 1:n]
384387

385-
code = quote end
388+
code = Expr(:block)
386389
for k = 1:n
387390
for j = m:-1:1
388391
if k == 1
@@ -414,7 +417,7 @@ end
414417
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
415418
init = [:($(X[i,j]) = B[$(LinearIndices(sb)[i, j])]) for i = 1:m, j = 1:n]
416419

417-
code = quote end
420+
code = Expr(:block)
418421
for k = 1:n
419422
for j = 1:m
420423
if k == 1
@@ -445,7 +448,7 @@ end
445448

446449
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
447450

448-
code = quote end
451+
code = Expr(:block)
449452
for k = 1:n
450453
for j = 1:m
451454
ex = :(B[$(LinearIndices(sb)[j, k])])
@@ -476,7 +479,7 @@ end
476479

477480
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
478481

479-
code = quote end
482+
code = Expr(:block)
480483
for k = 1:n
481484
for j = 1:m
482485
ex = :(B[$(LinearIndices(sb)[j, k])])
@@ -507,7 +510,7 @@ end
507510

508511
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
509512

510-
code = quote end
513+
code = Expr(:block)
511514
for k = 1:n
512515
for j = m:-1:1
513516
ex = :(B[$(LinearIndices(sb)[j, k])])
@@ -538,7 +541,7 @@ end
538541

539542
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
540543

541-
code = quote end
544+
code = Expr(:block)
542545
for k = 1:n
543546
for j = m:-1:1
544547
ex = :(B[$(LinearIndices(sb)[j, k])])
@@ -559,3 +562,129 @@ end
559562
@inbounds return similar_type(B, TAB)(tuple($(X...)))
560563
end
561564
end
565+
566+
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::UpperTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB}
567+
n = sa[1]
568+
if n != sb[1]
569+
throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])"))
570+
end
571+
572+
X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n]
573+
574+
TAB = promote_op(*, eltype(TA), eltype(TB))
575+
z = zero(TAB)
576+
577+
code = Expr(:block)
578+
for j = 1:n
579+
for i = 1:n
580+
if i > j
581+
push!(code.args, :($(X[i,j]) = $z))
582+
else
583+
ex = :(A.data[$(LinearIndices(sa)[i,i])] * B.data[$(LinearIndices(sb)[i,j])])
584+
for k = i+1:j
585+
ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])])
586+
end
587+
push!(code.args, :($(X[i,j]) = $ex))
588+
end
589+
end
590+
end
591+
592+
return quote
593+
@_inline_meta
594+
@inbounds $code
595+
return UpperTriangular(similar_type(B.data, $TAB)(tuple($(X...))))
596+
end
597+
598+
end
599+
600+
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::LowerTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB}
601+
n = sa[1]
602+
if n != sb[1]
603+
throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])"))
604+
end
605+
606+
X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n]
607+
608+
TAB = promote_op(*, eltype(TA), eltype(TB))
609+
z = zero(TAB)
610+
611+
code = Expr(:block)
612+
for j = 1:n
613+
for i = 1:n
614+
if i < j
615+
push!(code.args, :($(X[i,j]) = $z))
616+
else
617+
ex = :(A.data[$(LinearIndices(sa)[i,j])] * B.data[$(LinearIndices(sb)[j,j])])
618+
for k = j+1:i
619+
ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])])
620+
end
621+
push!(code.args, :($(X[i,j]) = $ex))
622+
end
623+
end
624+
end
625+
626+
return quote
627+
@_inline_meta
628+
@inbounds $code
629+
return LowerTriangular(similar_type(B.data, $TAB)(tuple($(X...))))
630+
end
631+
632+
end
633+
634+
635+
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::LowerTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB}
636+
n = sa[1]
637+
if n != sb[1]
638+
throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])"))
639+
end
640+
641+
X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n]
642+
643+
code = Expr(:block)
644+
for j = 1:n
645+
for i = 1:n
646+
k1 = max(i,j)
647+
ex = :(A.data[$(LinearIndices(sa)[i,k1])] * B.data[$(LinearIndices(sb)[k1,j])])
648+
for k = k1+1:n
649+
ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])])
650+
end
651+
push!(code.args, :($(X[i,j]) = $ex))
652+
end
653+
end
654+
655+
return quote
656+
@_inline_meta
657+
@inbounds $code
658+
TAB = promote_op(*, eltype(TA), eltype(TB))
659+
return similar_type(B.data, TAB)(tuple($(X...)))
660+
end
661+
662+
end
663+
664+
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::UpperTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB}
665+
n = sa[1]
666+
if n != sb[1]
667+
throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])"))
668+
end
669+
670+
X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n]
671+
672+
code = Expr(:block)
673+
for j = 1:n
674+
for i = 1:n
675+
ex = :(A.data[$(LinearIndices(sa)[i,1])] * B.data[$(LinearIndices(sb)[1,j])])
676+
for k = 2:min(i,j)
677+
ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])])
678+
end
679+
push!(code.args, :($(X[i,j]) = $ex))
680+
end
681+
end
682+
683+
return quote
684+
@_inline_meta
685+
@inbounds $code
686+
TAB = promote_op(*, eltype(TA), eltype(TB))
687+
return similar_type(B.data, TAB)(tuple($(X...)))
688+
end
689+
690+
end

test/triangular.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,29 @@ end
8282
end
8383
end
8484

85+
@testset "Triangular-triangular multiplication" begin
86+
for n in (1, 2, 3, 4),
87+
eltyA in (Float64, ComplexF64, Int),
88+
eltyB in (Float64, ComplexF64, Int),
89+
(ta, uploa) in ((UpperTriangular, :U), (LowerTriangular, :L)),
90+
(tb, uplob) in ((UpperTriangular, :U), (LowerTriangular, :L))
91+
92+
A = ta(eltyA == Int ? rand(1:7, n, n) : rand(eltyA, n, n))
93+
B = tb(eltyB == Int ? rand(1:7, n, n) : rand(eltyB, n, n))
94+
95+
SA = ta(SMatrix{n,n}(A.data))
96+
SB = tb(SMatrix{n,n}(B.data))
97+
98+
eltyAB = Base.promote_op(*, eltyA, eltyB)
99+
100+
@test SA*SB A*B
101+
@test eltype(SA*SB) == eltyAB
102+
@test SA*SB isa (ta===tb ? ta : SMatrix)
103+
104+
end
105+
106+
end
107+
85108
@testset "Triangular-matrix division" begin
86109
for n in (1, 2, 3, 4),
87110
eltyA in (Float64, ComplexF64, Int),

0 commit comments

Comments
 (0)