diff --git a/Project.toml b/Project.toml index e571881d..6228868c 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "90137ffa-7385-5640-81b9-e52037218182" version = "1.5.6" [deps] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -10,6 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] julia = "1.6" +ConstructionBase = "1" StaticArraysCore = "~1.3.0" [extras] diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index f4ad622d..c531dc95 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -10,6 +10,8 @@ import Base: getindex, setindex!, size, similar, vec, show, length, convert, pro import Statistics: mean +import ConstructionBase: constructorof + using Random import Random: rand, randn, randexp, rand!, randn!, randexp! using Core.Compiler: return_type @@ -128,6 +130,7 @@ include("svd.jl") include("qr.jl") include("deque.jl") include("flatten.jl") +include("constructorof.jl") include("io.jl") include("pinv.jl") diff --git a/src/constructorof.jl b/src/constructorof.jl new file mode 100644 index 00000000..4880f66c --- /dev/null +++ b/src/constructorof.jl @@ -0,0 +1,9 @@ +# keep the size when reconstructing arrays +# eltype can be different +constructorof(sa::Type{<:SArray{S}}) where {S} = SArray{S} +constructorof(sa::Type{<:MArray{S}}) where {S} = MArray{S} + +# don't keep neither size nor eltype for vectors: +# both are unambiguously determined by the values +constructorof(::Type{<:SVector}) = SVector +constructorof(::Type{<:MVector}) = MVector diff --git a/test/constructorof.jl b/test/constructorof.jl new file mode 100644 index 00000000..3c40cd78 --- /dev/null +++ b/test/constructorof.jl @@ -0,0 +1,25 @@ +using StaticArrays +using Test +using ConstructionBase: constructorof + +@testset "constructorof" begin + sa = @SVector [2, 4, 6, 8] + sa2 = constructorof(typeof(sa))((3.0, 5.0, 7.0, 9.0)) + @test sa2 === @SVector [3.0, 5.0, 7.0, 9.0] + + ma = @MMatrix [2.0 4.0; 6.0 8.0] + ma2 = constructorof(typeof(ma))((1, 2, 3, 4)) + @test ma2 isa MArray{Tuple{2,2},Int,2,4} + @test all(ma2 .=== @MMatrix [1 3; 2 4]) + + for T in (SVector, MVector) + @test constructorof(T)((1, 2, 3))::T == T((1, 2, 3)) + @test constructorof(T{3})((1, 2, 3))::T == T((1, 2, 3)) + @test constructorof(T{3})((1, 2))::T == T((1, 2)) + @test constructorof(T{3, Symbol})((1, 2, 3))::T == T((1, 2, 3)) + @test constructorof(T{3, Symbol})((1, 2))::T == T((1, 2)) + @test constructorof(T{3, X} where {X})((1, 2, 3))::T == T((1, 2, 3)) + @test constructorof(T{3, X} where {X})((1, 2))::T == T((1, 2)) + @test constructorof(T{X, Symbol} where {X})((1, 2, 3))::T == T((1, 2, 3)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index bd9e4502..5c8fb00c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -78,6 +78,7 @@ if TEST_GROUP ∈ ["", "all", "group-B"] addtests("chol.jl") # hermitian_type(::Type{Any}) for block algorithm addtests("deque.jl") addtests("flatten.jl") + addtests("constructorof.jl") addtests("io.jl") addtests("svd.jl") end