Skip to content

Commit f445984

Browse files
authored
Merge pull request #507 from JuliaArrays/flatten
Special case flatten of iterators of static arrays as done for tuples
2 parents f127c4a + 76eea3c commit f445984

File tree

4 files changed

+23
-0
lines changed

4 files changed

+23
-0
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ include("svd.jl")
114114
include("lu.jl")
115115
include("qr.jl")
116116
include("deque.jl")
117+
include("flatten.jl")
117118
include("io.jl")
118119

119120
include("FixedSizeArrays.jl")

src/flatten.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Special case flatten of iterators of static arrays.
2+
import Base.Iterators: flatten_iteratorsize, flatten_length
3+
flatten_iteratorsize(::Union{Base.HasShape, Base.HasLength}, ::Type{<:StaticArray{S}}) where {S} = Base.HasLength()
4+
function flatten_length(f, T::Type{<:StaticArray{S}}) where {S}
5+
length(T)*length(f.it)
6+
end

test/flatten.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using StaticArrays, Test
2+
3+
@testset "Iterators.flatten" begin
4+
for x in [SVector(1.0, 2.0), MVector(1.0, 2.0),
5+
@SMatrix([1.0 2.0; 3.0 4.0]), @MMatrix([1.0 2.0]),
6+
Size(1,2)([1.0 2.0])
7+
]
8+
X = [x,x,x]
9+
@test length(Iterators.flatten(X)) == length(X)*length(x)
10+
@test collect(Iterators.flatten(typeof(x)[])) == []
11+
@test collect(Iterators.flatten(X)) == [x..., x..., x...]
12+
end
13+
@test collect(Iterators.flatten([SVector(1,1), SVector(1)])) == [1,1,1]
14+
@test_throws ArgumentError length(Iterators.flatten([SVector(1,1), SVector(1)]))
15+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ include("lu.jl")
4040
Random.seed!(42); include("qr.jl")
4141
Random.seed!(42); include("chol.jl") # hermitian_type(::Type{Any}) for block algorithm
4242
include("deque.jl")
43+
include("flatten.jl")
4344
include("io.jl")
4445
include("svd.jl")
4546
Random.seed!(42); include("fixed_size_arrays.jl")

0 commit comments

Comments
 (0)