Skip to content

Commit 059c66e

Browse files
committed
add specialized methods for Triangular*Triangular
1 parent 8609799 commit 059c66e

File tree

1 file changed

+130
-1
lines changed

1 file changed

+130
-1
lines changed

src/triangular.jl

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Transpose{<:Any,<:StaticVecOrMat}) =
1818
transpose(transpose(B) * transpose(A))
1919

20+
const StaticULT = Union{UpperTriangular{<:Any,<:StaticMatrix},LowerTriangular{<:Any,<:StaticMatrix}}
21+
2022
@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::StaticVecOrMat) = _A_mul_B(Size(A), Size(B), A, B)
2123
@inline Base.:*(A::StaticVecOrMat, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = _A_mul_B(Size(A), Size(B), A, B)
22-
@inline Base.:\(A::Union{UpperTriangular{<:Any,<:StaticMatrix},LowerTriangular{<:Any,<:StaticMatrix}}, B::StaticVecOrMat) = _A_ldiv_B(Size(A), Size(B), A, B)
24+
@inline Base.:*(A::StaticULT, B::StaticULT) = _A_mul_B(Size(A), Size(B), A, B)
25+
@inline Base.:\(A::StaticULT, B::StaticVecOrMat) = _A_ldiv_B(Size(A), Size(B), A, B)
2326

2427

2528
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB}
@@ -559,3 +562,129 @@ end
559562
@inbounds return similar_type(B, TAB)(tuple($(X...)))
560563
end
561564
end
565+
566+
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::UpperTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB}
567+
n = sa[1]
568+
if n != sb[1]
569+
throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])"))
570+
end
571+
572+
X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n]
573+
574+
TAB = promote_op(*, eltype(TA), eltype(TB))
575+
z = zero(TAB)
576+
577+
code = quote end
578+
for j = 1:n
579+
for i = 1:n
580+
if i > j
581+
push!(code.args, :($(X[i,j]) = $z))
582+
else
583+
ex = :(A.data[$(LinearIndices(sa)[i,i])] * B.data[$(LinearIndices(sb)[i,j])])
584+
for k = i+1:j
585+
ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])])
586+
end
587+
push!(code.args, :($(X[i,j]) = $ex))
588+
end
589+
end
590+
end
591+
592+
return quote
593+
@_inline_meta
594+
@inbounds $code
595+
return UpperTriangular(similar_type(B.data, $TAB)(tuple($(X...))))
596+
end
597+
598+
end
599+
600+
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::LowerTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB}
601+
n = sa[1]
602+
if n != sb[1]
603+
throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])"))
604+
end
605+
606+
X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n]
607+
608+
TAB = promote_op(*, eltype(TA), eltype(TB))
609+
z = zero(TAB)
610+
611+
code = quote end
612+
for j = 1:n
613+
for i = 1:n
614+
if i < j
615+
push!(code.args, :($(X[i,j]) = $z))
616+
else
617+
ex = :(A.data[$(LinearIndices(sa)[i,j])] * B.data[$(LinearIndices(sb)[j,j])])
618+
for k = j+1:i
619+
ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])])
620+
end
621+
push!(code.args, :($(X[i,j]) = $ex))
622+
end
623+
end
624+
end
625+
626+
return quote
627+
@_inline_meta
628+
@inbounds $code
629+
return LowerTriangular(similar_type(B.data, $TAB)(tuple($(X...))))
630+
end
631+
632+
end
633+
634+
635+
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::LowerTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB}
636+
n = sa[1]
637+
if n != sb[1]
638+
throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])"))
639+
end
640+
641+
X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n]
642+
643+
code = quote end
644+
for j = 1:n
645+
for i = 1:n
646+
k1 = max(i,j)
647+
ex = :(A.data[$(LinearIndices(sa)[i,k1])] * B.data[$(LinearIndices(sb)[k1,j])])
648+
for k = k1+1:n
649+
ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])])
650+
end
651+
push!(code.args, :($(X[i,j]) = $ex))
652+
end
653+
end
654+
655+
return quote
656+
@_inline_meta
657+
@inbounds $code
658+
TAB = promote_op(*, eltype(TA), eltype(TB))
659+
return similar_type(B.data, TAB)(tuple($(X...)))
660+
end
661+
662+
end
663+
664+
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::UpperTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB}
665+
n = sa[1]
666+
if n != sb[1]
667+
throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])"))
668+
end
669+
670+
X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n]
671+
672+
code = quote end
673+
for j = 1:n
674+
for i = 1:n
675+
ex = :(A.data[$(LinearIndices(sa)[i,1])] * B.data[$(LinearIndices(sb)[1,j])])
676+
for k = 2:min(i,j)
677+
ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])])
678+
end
679+
push!(code.args, :($(X[i,j]) = $ex))
680+
end
681+
end
682+
683+
return quote
684+
@_inline_meta
685+
@inbounds $code
686+
TAB = promote_op(*, eltype(TA), eltype(TB))
687+
return similar_type(B.data, TAB)(tuple($(X...)))
688+
end
689+
690+
end

0 commit comments

Comments
 (0)