Skip to content

Commit ef1b6d3

Browse files
mcognettadkarrasch
andauthored
Adding inplace multiplication for (unit)triangular matrices (#36972)
Co-authored-by: Daniel Karrasch <Daniel.Karrasch@posteo.de>
1 parent f0046a0 commit ef1b6d3

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,29 @@ mul!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = lmul!(A,
706706
mul!(C::AbstractVector, A::AbstractTriangular{<:Any,<:Adjoint}, B::Transpose{<:Any,<:AbstractVecOrMat}) = throw(MethodError(mul!, (C, A, B)))
707707
mul!(C::AbstractVector, A::AbstractTriangular{<:Any,<:Transpose}, B::Transpose{<:Any,<:AbstractVecOrMat}) = throw(MethodError(mul!, (C, A, B)))
708708

709+
# preserve triangular structure in in-place multiplication
710+
for (cty, aty, bty) in ((:UpperTriangular, :UpperTriangular, :UpperTriangular),
711+
(:UpperTriangular, :UpperTriangular, :UnitUpperTriangular),
712+
(:UpperTriangular, :UnitUpperTriangular, :UpperTriangular),
713+
(:UnitUpperTriangular, :UnitUpperTriangular, :UnitUpperTriangular),
714+
(:LowerTriangular, :LowerTriangular, :LowerTriangular),
715+
(:LowerTriangular, :LowerTriangular, :UnitLowerTriangular),
716+
(:LowerTriangular, :UnitLowerTriangular, :LowerTriangular),
717+
(:UnitLowerTriangular, :UnitLowerTriangular, :UnitLowerTriangular))
718+
@eval function mul!(C::$cty, A::$aty, B::$bty)
719+
lmul!(A, copyto!(parent(C), B))
720+
return C
721+
end
722+
723+
@eval @inline function mul!(C::$cty, A::$aty, B::$bty, alpha::Number, beta::Number)
724+
if isone(alpha) && iszero(beta)
725+
return mul!(C, A, B)
726+
else
727+
return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
728+
end
729+
end
730+
end
731+
709732
# direct multiplication/division
710733
for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
711734
(:UnitLowerTriangular, 'L', 'U'),

stdlib/LinearAlgebra/test/triangular.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,47 @@ end
630630
end
631631
end
632632

633+
@testset "inplace mul of appropriate types should preserve triagular structure" begin
634+
for elty1 in (Float64, ComplexF32), elty2 in (Float64, ComplexF32)
635+
T = promote_type(elty1, elty2)
636+
M1 = rand(elty1, 5, 5)
637+
M2 = rand(elty2, 5, 5)
638+
A = UpperTriangular(M1)
639+
A2 = UpperTriangular(M2)
640+
Au = UnitUpperTriangular(M1)
641+
Au2 = UnitUpperTriangular(M2)
642+
B = LowerTriangular(M1)
643+
B2 = LowerTriangular(M2)
644+
Bu = UnitLowerTriangular(M1)
645+
Bu2 = UnitLowerTriangular(M2)
646+
647+
@test mul!(similar(A), A, A)::typeof(A) == A*A
648+
@test mul!(similar(A, T), A, A2) A*A2
649+
@test mul!(similar(A, T), A2, A) A2*A
650+
@test mul!(typeof(similar(A, T))(A), A, A2, 2.0, 3.0) 2.0*A*A2 + 3.0*A
651+
@test mul!(typeof(similar(A2, T))(A2), A2, A, 2.0, 3.0) 2.0*A2*A + 3.0*A2
652+
653+
@test mul!(similar(A), A, Au)::typeof(A) == A*Au
654+
@test mul!(similar(A), Au, A)::typeof(A) == Au*A
655+
@test mul!(similar(Au), Au, Au)::typeof(Au) == Au*Au
656+
@test mul!(similar(A, T), A, Au2) A*Au2
657+
@test mul!(similar(A, T), Au2, A) Au2*A
658+
@test mul!(similar(Au2), Au2, Au2) == Au2*Au2
659+
660+
@test mul!(similar(B), B, B)::typeof(B) == B*B
661+
@test mul!(similar(B, T), B, B2) B*B2
662+
@test mul!(similar(B, T), B2, B) B2*B
663+
@test mul!(typeof(similar(B, T))(B), B, B2, 2.0, 3.0) 2.0*B*B2 + 3.0*B
664+
@test mul!(typeof(similar(B2, T))(B2), B2, B, 2.0, 3.0) 2.0*B2*B + 3.0*B2
665+
666+
@test mul!(similar(B), B, Bu)::typeof(B) == B*Bu
667+
@test mul!(similar(B), Bu, B)::typeof(B) == Bu*B
668+
@test mul!(similar(Bu), Bu, Bu)::typeof(Bu) == Bu*Bu
669+
@test mul!(similar(B, T), B, Bu2) B*Bu2
670+
@test mul!(similar(B, T), Bu2, B) Bu2*B
671+
end
672+
end
673+
633674
@testset "special printing of Lower/UpperTriangular" begin
634675
@test occursin(r"3×3 (LinearAlgebra\.)?LowerTriangular{Int64, Matrix{Int64}}:\n 2 ⋅ ⋅\n 2 2 ⋅\n 2 2 2",
635676
sprint(show, MIME"text/plain"(), LowerTriangular(2ones(Int64,3,3))))

0 commit comments

Comments
 (0)