Skip to content

Commit b953e61

Browse files
jishnubKristofferC
authored andcommitted
copytrito! for triangular matrices (#56620)
This does two things: 1. Forward `copytrito!` for triangular matrices to the parent in case the specified `uplo` corresponds to the stored part. This works because these matrices share their elements with the parents for the stored part. 2. Make `copytrito!` only copy the diagonal if the `uplo` corresponds to the non-stored part. This makes `copytrito!` involving a triangular matrix equivalent to that involving its parent if the filled part is copied, and O(N) otherwise. Examples of improvements in performance: ```julia julia> using LinearAlgebra julia> A1 = UpperTriangular(rand(400,400)); julia> A2 = similar(A1); julia> @Btime copytrito!($A2, $A1, 'U'); 70.753 μs (0 allocations: 0 bytes) # nightly v"1.12.0-DEV.1657" 26.143 μs (0 allocations: 0 bytes) # this PR julia> @Btime copytrito!(parent($A2), $A1, 'U'); 56.025 μs (0 allocations: 0 bytes) # nightly 26.633 μs (0 allocations: 0 bytes) # this PR ```
1 parent 1207f3c commit b953e61

File tree

3 files changed

+58
-5
lines changed

3 files changed

+58
-5
lines changed

src/generic.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,17 +2069,20 @@ function copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar)
20692069
require_one_based_indexing(A, B)
20702070
BLAS.chkuplo(uplo)
20712071
m,n = size(A)
2072-
m1,n1 = size(B)
20732072
A = Base.unalias(B, A)
20742073
if uplo == 'U'
2075-
LAPACK.lacpy_size_check((m1, n1), (n < m ? n : m, n))
2074+
LAPACK.lacpy_size_check(size(B), (n < m ? n : m, n))
20762075
for j in axes(A,2), i in axes(A,1)[begin : min(j,end)]
2077-
@inbounds B[i,j] = A[i,j]
2076+
# extract the parents for UpperTriangular matrices
2077+
Bv, Av = uppertridata(B), uppertridata(A)
2078+
@inbounds Bv[i,j] = Av[i,j]
20782079
end
20792080
else # uplo == 'L'
2080-
LAPACK.lacpy_size_check((m1, n1), (m, m < n ? m : n))
2081+
LAPACK.lacpy_size_check(size(B), (m, m < n ? m : n))
20812082
for j in axes(A,2), i in axes(A,1)[j:end]
2082-
@inbounds B[i,j] = A[i,j]
2083+
# extract the parents for LowerTriangular matrices
2084+
Bv, Av = lowertridata(B), lowertridata(A)
2085+
@inbounds Bv[i,j] = Av[i,j]
20832086
end
20842087
end
20852088
return B

src/triangular.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,36 @@ end
633633
return dest
634634
end
635635

636+
Base.@constprop :aggressive function copytrito_triangular!(Bdata, Adata, uplo, uplomatch, sz)
637+
if uplomatch
638+
copytrito!(Bdata, Adata, uplo)
639+
else
640+
BLAS.chkuplo(uplo)
641+
LAPACK.lacpy_size_check(size(Bdata), sz)
642+
# only the diagonal is copied in this case
643+
copyto!(diagview(Bdata), diagview(Adata))
644+
end
645+
return Bdata
646+
end
647+
648+
function copytrito!(B::UpperTriangular, A::UpperTriangular, uplo::AbstractChar)
649+
m,n = size(A)
650+
copytrito_triangular!(B.data, A.data, uplo, uplo == 'U', (m, m < n ? m : n))
651+
return B
652+
end
653+
function copytrito!(B::LowerTriangular, A::LowerTriangular, uplo::AbstractChar)
654+
m,n = size(A)
655+
copytrito_triangular!(B.data, A.data, uplo, uplo == 'L', (n < m ? n : m, n))
656+
return B
657+
end
658+
659+
uppertridata(A) = A
660+
lowertridata(A) = A
661+
# we restrict these specializations only to strided matrices to avoid cases where an UpperTriangular type
662+
# doesn't share its indexing with the parent
663+
uppertridata(A::UpperTriangular{<:Any, <:StridedMatrix}) = parent(A)
664+
lowertridata(A::LowerTriangular{<:Any, <:StridedMatrix}) = parent(A)
665+
636666
@inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
637667
@stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta))
638668
@inline _lscale_add!(A::AbstractTriangular, B::Number, C::AbstractTriangular, alpha::Number, beta::Number) =

test/triangular.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,4 +1396,24 @@ end
13961396
@test exp(log(M)) M
13971397
end
13981398

1399+
@testset "copytrito!" begin
1400+
for T in (UpperTriangular, LowerTriangular)
1401+
M = Matrix{BigFloat}(undef, 2, 2)
1402+
M[1,1] = M[2,2] = 3
1403+
U = T(M)
1404+
isupper = U isa UpperTriangular
1405+
M[1+!isupper, 1+isupper] = 4
1406+
uplo, loup = U isa UpperTriangular ? ('U', 'L') : ('L', 'U' )
1407+
@test copytrito!(similar(U), U, uplo) == U
1408+
@test copytrito!(zero(M), U, uplo) == U
1409+
@test copytrito!(similar(U), Array(U), uplo) == U
1410+
@test copytrito!(zero(U), U, loup) == Diagonal(U)
1411+
@test copytrito!(similar(U), MyTriangular(U), uplo) == U
1412+
@test copytrito!(zero(M), MyTriangular(U), uplo) == U
1413+
Ubig = T(similar(M, (3,3)))
1414+
copytrito!(Ubig, U, uplo)
1415+
@test Ubig[axes(U)...] == U
1416+
end
1417+
end
1418+
13991419
end # module TestTriangular

0 commit comments

Comments
 (0)