From 2282328478f736d7c7f1dbebb53bef43623e71f4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 27 Apr 2025 23:07:39 +0530 Subject: [PATCH 1/5] Out-of-place `triu`/`tril` for `Symmetric` in each branch --- src/symmetric.jl | 24 ++++++++++++------------ src/triangular.jl | 5 +++++ test/symmetric.jl | 23 +++++++++++++++++++++++ 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index 804a063b..c768f091 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -474,25 +474,25 @@ Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo) # tril/triu function tril(A::Hermitian, k::Integer=0) if A.uplo == 'U' && k <= 0 - return tril!(copy(A.data'),k) + return tril_maybe_inplace(copy(A.data'),k) elseif A.uplo == 'U' && k > 0 - return tril!(copy(A.data'),-1) + tril!(triu(A.data),k) + return tril_maybe_inplace(copy(A.data'),-1) + tril_maybe_inplace(triu(A.data),k) elseif A.uplo == 'L' && k <= 0 return tril(A.data,k) else - return tril(A.data,-1) + tril!(triu!(copy(A.data')),k) + return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(A.data')),k) end end function tril(A::Symmetric, k::Integer=0) if A.uplo == 'U' && k <= 0 - return tril!(copy(transpose(A.data)),k) + return tril_maybe_inplace(copy(transpose(A.data)),k) elseif A.uplo == 'U' && k > 0 - return tril!(copy(transpose(A.data)),-1) + tril!(triu(A.data),k) + return tril_maybe_inplace(copy(transpose(A.data)),-1) + tril_maybe_inplace(triu(A.data),k) elseif A.uplo == 'L' && k <= 0 return tril(A.data,k) else - return tril(A.data,-1) + tril!(triu!(copy(transpose(A.data))),k) + return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(transpose(A.data))),k) end end @@ -500,11 +500,11 @@ function triu(A::Hermitian, k::Integer=0) if A.uplo == 'U' && k >= 0 return triu(A.data,k) elseif A.uplo == 'U' && k < 0 - return triu(A.data,1) + triu!(tril!(copy(A.data')),k) + return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(A.data')),k) elseif A.uplo == 'L' && k >= 0 - return triu!(copy(A.data'),k) + return triu_maybe_inplace(copy(A.data'),k) else - return triu!(copy(A.data'),1) + triu!(tril(A.data),k) + return triu_maybe_inplace(copy(A.data'),1) + triu_maybe_inplace(tril(A.data),k) end end @@ -512,11 +512,11 @@ function triu(A::Symmetric, k::Integer=0) if A.uplo == 'U' && k >= 0 return triu(A.data,k) elseif A.uplo == 'U' && k < 0 - return triu(A.data,1) + triu!(tril!(copy(transpose(A.data))),k) + return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(transpose(A.data))),k) elseif A.uplo == 'L' && k >= 0 - return triu!(copy(transpose(A.data)),k) + return triu_maybe_inplace(copy(transpose(A.data)),k) else - return triu!(copy(transpose(A.data)),1) + triu!(tril(A.data),k) + return triu_maybe_inplace(copy(transpose(A.data)),1) + triu_maybe_inplace(tril(A.data),k) end end diff --git a/src/triangular.jl b/src/triangular.jl index 26ad4204..7c99f85a 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -484,6 +484,11 @@ function tril!(A::UnitLowerTriangular, k::Integer=0) return tril!(LowerTriangular(A.data), k) end +tril_maybe_inplace(A, k::Integer=0) = tril(A, k) +triu_maybe_inplace(A, k::Integer=0) = triu(A, k) +tril_maybe_inplace(A::StridedMatrix, k::Integer=0) = tril!(A, k) +triu_maybe_inplace(A::StridedMatrix, k::Integer=0) = triu!(A, k) + adjoint(A::LowerTriangular) = UpperTriangular(adjoint(A.data)) adjoint(A::UpperTriangular) = LowerTriangular(adjoint(A.data)) adjoint(A::UnitLowerTriangular) = UnitUpperTriangular(adjoint(A.data)) diff --git a/test/symmetric.jl b/test/symmetric.jl index ae97b453..0f8dec3f 100644 --- a/test/symmetric.jl +++ b/test/symmetric.jl @@ -1206,4 +1206,27 @@ end end end +@testset "triu/tril with immutable arrays" begin + struct ImmutableMatrix{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T} + a :: A + end + Base.size(A::ImmutableMatrix) = size(A.a) + Base.getindex(A::ImmutableMatrix, i::Int, j::Int) = getindex(A.a, i, j) + Base.copy(A::ImmutableMatrix) = A + LinearAlgebra.adjoint(A::ImmutableMatrix) = ImmutableMatrix(adjoint(A.a)) + LinearAlgebra.transpose(A::ImmutableMatrix) = ImmutableMatrix(transpose(A.a)) + + A = ImmutableMatrix([1 2; 3 4]) + for T in (Symmetric, Hermitian), uplo in (:U, :L) + H = T(A, uplo) + MH = Matrix(H) + @test triu(H,-1) == triu(MH,-1) + @test triu(H) == triu(MH) + @test triu(H,1) == triu(MH,1) + @test tril(H,1) == tril(MH,1) + @test tril(H) == tril(MH) + @test tril(H,-1) == tril(MH,-1) + end +end + end # module TestSymmetric From 6feca41547913241b91676e7c801a4294983650c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 27 Apr 2025 23:09:56 +0530 Subject: [PATCH 2/5] Undo `_conjugation` change --- src/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index c768f091..51baa4a6 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -293,7 +293,7 @@ Base.dataids(A::HermOrSym) = Base.dataids(parent(A)) Base.unaliascopy(A::Hermitian) = Hermitian(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) Base.unaliascopy(A::Symmetric) = Symmetric(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) -_conjugation(::Union{Symmetric, Hermitian{<:Real}}) = transpose +_conjugation(::Symmetric) = transpose _conjugation(::Hermitian) = adjoint diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo)) From 68b4acec698994c6a118b6290e5118ec2cf274c3 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 16 Jun 2025 14:10:01 +0530 Subject: [PATCH 3/5] Use `ImmutableArrays` test helper --- test/symmetric.jl | 11 +---------- test/testhelpers/ImmutableArrays.jl | 3 +++ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/test/symmetric.jl b/test/symmetric.jl index 0f8dec3f..a9d1a883 100644 --- a/test/symmetric.jl +++ b/test/symmetric.jl @@ -1207,16 +1207,7 @@ end end @testset "triu/tril with immutable arrays" begin - struct ImmutableMatrix{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T} - a :: A - end - Base.size(A::ImmutableMatrix) = size(A.a) - Base.getindex(A::ImmutableMatrix, i::Int, j::Int) = getindex(A.a, i, j) - Base.copy(A::ImmutableMatrix) = A - LinearAlgebra.adjoint(A::ImmutableMatrix) = ImmutableMatrix(adjoint(A.a)) - LinearAlgebra.transpose(A::ImmutableMatrix) = ImmutableMatrix(transpose(A.a)) - - A = ImmutableMatrix([1 2; 3 4]) + A = ImmutableArray([1 2; 3 4]) for T in (Symmetric, Hermitian), uplo in (:U, :L) H = T(A, uplo) MH = Matrix(H) diff --git a/test/testhelpers/ImmutableArrays.jl b/test/testhelpers/ImmutableArrays.jl index 8f2d23be..014e8110 100644 --- a/test/testhelpers/ImmutableArrays.jl +++ b/test/testhelpers/ImmutableArrays.jl @@ -28,4 +28,7 @@ AbstractArray{T,N}(A::ImmutableArray{S,N}) where {S,T,N} = ImmutableArray(Abstra Base.copy(A::ImmutableArray) = ImmutableArray(copy(A.data)) Base.zero(A::ImmutableArray) = ImmutableArray(zero(A.data)) +Base.adjoint(A::ImmutableArray) = ImmutableArray(adjoint(A.data)) +Base.transpose(A::ImmutableArray) = ImmutableArray(transpose(A.data)) + end From 9b9ab75bd0caf3855be8f259be0be9e7474808f0 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 16 Jun 2025 14:16:37 +0530 Subject: [PATCH 4/5] Revert `_conjugation` change --- src/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index 51baa4a6..c768f091 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -293,7 +293,7 @@ Base.dataids(A::HermOrSym) = Base.dataids(parent(A)) Base.unaliascopy(A::Hermitian) = Hermitian(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) Base.unaliascopy(A::Symmetric) = Symmetric(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) -_conjugation(::Symmetric) = transpose +_conjugation(::Union{Symmetric, Hermitian{<:Real}}) = transpose _conjugation(::Hermitian) = adjoint diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo)) From bbcfa7d75d3904c142c2a9b5ee38f54b8c68845a Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 20 Jun 2025 16:12:41 +0530 Subject: [PATCH 5/5] Preserve wrapper in AbstractMatrix constructor for Adjoint/Transpose --- src/adjtrans.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/adjtrans.jl b/src/adjtrans.jl index 06848d8e..68acb200 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -332,6 +332,10 @@ wrapperop(::Transpose) = transpose _wrapperop(x) = wrapperop(x) _wrapperop(::Adjoint{<:Real}) = transpose +# equivalent to wrapperop, but returns the type of the wrapper +wrappertype(::Adjoint) = Adjoint +wrappertype(::Transpose) = Transpose + # the following fallbacks can be removed if Adjoint/Transpose are restricted to AbstractVecOrMat size(A::AdjOrTrans) = reverse(size(A.parent)) axes(A::AdjOrTrans) = reverse(axes(A.parent)) @@ -391,7 +395,7 @@ similar(A::AdjOrTrans, ::Type{T}) where {T} = similar(A.parent, T, axes(A)) similar(A::AdjOrTrans, ::Type{T}, dims::Dims{N}) where {T,N} = similar(A.parent, T, dims) # AbstractMatrix{T} constructor for adjtrans vector: preserve wrapped type -AbstractMatrix{T}(A::AdjOrTransAbsVec) where {T} = wrapperop(A)(AbstractVector{T}(A.parent)) +AbstractMatrix{T}(A::AdjOrTransAbsVec) where {T} = wrappertype(A)(AbstractVector{T}(A.parent)) # sundry basic definitions parent(A::AdjOrTrans) = A.parent