From 6f5c86221227d7d8d463497bb7ace362f0765671 Mon Sep 17 00:00:00 2001 From: jishnub Date: Wed, 16 Jun 2021 23:25:29 +0400 Subject: [PATCH 1/3] axes(::SDiagonal) is statically sized vector indexing for SDiagonal produces SArrays --- src/SDiagonal.jl | 7 +++++++ src/indexing.jl | 26 ++++++++++++++++++++++++++ test/SDiagonal.jl | 11 +++++++++++ 3 files changed, 44 insertions(+) diff --git a/src/SDiagonal.jl b/src/SDiagonal.jl index d1e92caa..1112f0fb 100644 --- a/src/SDiagonal.jl +++ b/src/SDiagonal.jl @@ -17,6 +17,11 @@ SDiagonal(a::StaticMatrix{N,N,T}) where {N,T} = Diagonal(diag(a)) size(::Type{<:SDiagonal{N}}) where {N} = (N,N) size(::Type{<:SDiagonal{N}}, d::Int) where {N} = d > 2 ? 1 : N +Base.axes(D::SDiagonal) = (ax = axes(diag(D), 1); (ax, ax)) +Base.axes(D::SDiagonal, d) = d <= 2 ? axes(D)[d] : SOneTo(1) + +Base.reshape(a::SDiagonal, s::Tuple{SOneTo,Vararg{SOneTo}}) = reshape(a, homogenize_shape(s)) + # define specific methods to avoid allocating mutable arrays \(D::SDiagonal, b::AbstractVector) = D.diag .\ b \(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity @@ -56,3 +61,5 @@ function inv(D::SDiagonal) check_singular(D) SDiagonal(inv.(D.diag)) end + +Base.copy(D::SDiagonal) = Diagonal(copy(diag(D))) diff --git a/src/indexing.jl b/src/indexing.jl index dffb0cef..6f87163e 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -377,3 +377,29 @@ Base.unsafe_view(A::AbstractArray, i1::StaticIndexing, indices::StaticIndexing.. # the tuple indices has to have at least one element to prevent infinite # recursion when viewing a zero-dimensional array (see issue #705) Base.SubArray(A::AbstractArray, indices::Tuple{StaticIndexing, Vararg{StaticIndexing}}) = Base.SubArray(A, map(unwrap, indices)) + +########################################################### +# SDiagonal +########################################################### + +# SDiagonal uses Cartesian indexing, and the canonical indexing methods shadow getindex for Diagonal +# these are needed for ambiguity resolution +@inline function getindex(D::SDiagonal, i::Int, j::Int) + @boundscheck checkbounds(D, i, j) + if i == j + @inbounds r = diag(D)[i] + else + r = LinearAlgebra.diagzero(D, i, j) + end + r +end +@inline function getindex(D::SDiagonal, i::Int...) + @boundscheck checkbounds(D, i...) + @inbounds r = D[eachindex(D)[i...]] + r +end +# Ensure that vector indexing with static types lead to SArrays +@propagate_inbounds function getindex(a::SDiagonal, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...) + ar = reshape(a, Val(length(inds))) + _getindex(ar, index_sizes(Size(ar), inds...), inds) +end diff --git a/test/SDiagonal.jl b/test/SDiagonal.jl index ff94f849..135fd576 100644 --- a/test/SDiagonal.jl +++ b/test/SDiagonal.jl @@ -70,6 +70,15 @@ using StaticArrays, Test, LinearAlgebra @test length(m) === 4*4 + m2 = SMatrix{4,4}(m) + @test axes(m) === axes(m2) + @test axes(m, 1) === axes(m2, 1) + @test axes(m, 3) == SOneTo(1) + + @test m[:, 1] === SVector{4}(m[1,1], 0, 0, 0) + @test m[:, :] === m2 + @test m[2, 2, 1] === m[2, 2] + @test_throws Exception m[1] = 1 b = @SVector [2,-1,2,1] @@ -114,5 +123,7 @@ using StaticArrays, Test, LinearAlgebra @test m + zero(m) == m @test m + zero(typeof(m)) == m + + @test copy(m) === m end end From 9dd183dabc001173d5da400e9cce8af45d38fdd5 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 27 Jun 2022 18:58:23 +0530 Subject: [PATCH 2/3] invoke instead of duplicating getindex --- src/SDiagonal.jl | 1 - src/indexing.jl | 12 ++---------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/SDiagonal.jl b/src/SDiagonal.jl index 1112f0fb..d726eaba 100644 --- a/src/SDiagonal.jl +++ b/src/SDiagonal.jl @@ -17,7 +17,6 @@ SDiagonal(a::StaticMatrix{N,N,T}) where {N,T} = Diagonal(diag(a)) size(::Type{<:SDiagonal{N}}) where {N} = (N,N) size(::Type{<:SDiagonal{N}}, d::Int) where {N} = d > 2 ? 1 : N -Base.axes(D::SDiagonal) = (ax = axes(diag(D), 1); (ax, ax)) Base.axes(D::SDiagonal, d) = d <= 2 ? axes(D)[d] : SOneTo(1) Base.reshape(a::SDiagonal, s::Tuple{SOneTo,Vararg{SOneTo}}) = reshape(a, homogenize_shape(s)) diff --git a/src/indexing.jl b/src/indexing.jl index 6f87163e..98315e4d 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -385,18 +385,10 @@ Base.SubArray(A::AbstractArray, indices::Tuple{StaticIndexing, Vararg{StaticInde # SDiagonal uses Cartesian indexing, and the canonical indexing methods shadow getindex for Diagonal # these are needed for ambiguity resolution @inline function getindex(D::SDiagonal, i::Int, j::Int) - @boundscheck checkbounds(D, i, j) - if i == j - @inbounds r = diag(D)[i] - else - r = LinearAlgebra.diagzero(D, i, j) - end - r + invoke(getindex, Tuple{Diagonal, Int, Int}, D, i, j) end @inline function getindex(D::SDiagonal, i::Int...) - @boundscheck checkbounds(D, i...) - @inbounds r = D[eachindex(D)[i...]] - r + invoke(getindex, Tuple{Diagonal, Vararg{Int}}, D, i...) end # Ensure that vector indexing with static types lead to SArrays @propagate_inbounds function getindex(a::SDiagonal, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...) From 095186a3cd900634a6648e5fe1721f55cccf2400 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 28 Jun 2022 09:37:37 +0530 Subject: [PATCH 3/3] check for axes inference --- test/SDiagonal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/SDiagonal.jl b/test/SDiagonal.jl index 135fd576..7d1090c6 100644 --- a/test/SDiagonal.jl +++ b/test/SDiagonal.jl @@ -71,7 +71,7 @@ using StaticArrays, Test, LinearAlgebra @test length(m) === 4*4 m2 = SMatrix{4,4}(m) - @test axes(m) === axes(m2) + @test (@inferred axes(m)) === axes(m2) @test axes(m, 1) === axes(m2, 1) @test axes(m, 3) == SOneTo(1)