Skip to content

Commit 28e0482

Browse files
authored
Mapping over enumeration of a static array (#1107)
* Mapping over enumeration of a static array * bump version * use separate `enumerate_static`
1 parent bbf3a74 commit 28e0482

File tree

4 files changed

+44
-1
lines changed

4 files changed

+44
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.5.9"
3+
version = "1.5.10"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/StaticArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ export @MVector, @MMatrix, @MArray
5858

5959
export similar_type
6060
export push, pop, pushfirst, popfirst, insert, deleteat, setindex
61+
export enumerate_static
6162

6263
export StaticArraysCore
6364

src/mapreduce.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,37 @@ end
7777
end
7878
end
7979

80+
struct StaticEnumerate{TA}
81+
itr::TA
82+
end
83+
84+
enumerate_static(a::StaticArray) = StaticEnumerate(a)
85+
86+
@generated function map(f, a::StaticEnumerate{<:StaticArray})
87+
S = Size(a.parameters[1])
88+
if prod(S) == 0
89+
# In the empty case only, use inference to try figuring out a sensible
90+
# eltype, as is done in Base.collect and broadcast.
91+
# See https://github.com/JuliaArrays/StaticArrays.jl/issues/528
92+
return quote
93+
@_inline_meta
94+
T = Core.Compiler.return_type(f, Tuple{Tuple{Int,$(eltype(a.parameters[1]))}})
95+
@inbounds return similar_type(a.itr, T, $S)()
96+
end
97+
end
98+
99+
exprs = Vector{Expr}(undef, prod(S))
100+
for i 1:prod(S)
101+
exprs[i] = :(f(($i, a.itr[$i])))
102+
end
103+
104+
return quote
105+
@_inline_meta
106+
@inbounds elements = tuple($(exprs...))
107+
@inbounds return similar_type(typeof(a.itr), eltype(elements), $S)(elements)
108+
end
109+
end
110+
80111
@inline function map!(f, dest::StaticArray, a::StaticArray...)
81112
_map!(f, dest, same_size(dest, a...), a...)
82113
end

test/mapreduce.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,4 +240,15 @@ using Statistics: mean
240240
@test @inferred(reduce(vcat, v2)) === @SVector [1,2,3,4]
241241
@test @inferred(reduce(hcat, v2)) === @SMatrix [1 3; 2 4]
242242
end
243+
@testset "map over enumerate" begin
244+
# issue #1106
245+
v = @SVector [1, -2, 3, -4]
246+
m = @SMatrix [1 -2; 3 -4]
247+
v0 = SVector{0,Float64}()
248+
m0 = SMatrix{0,0,Float64}()
249+
@test @inferred(map(f -> f[1] * f[2], enumerate_static(v))) === @SVector [1, -4, 9, -16]
250+
@test @inferred(map(f -> f[1] * f[2], enumerate_static(m))) === @SMatrix [1 -6; 6 -16]
251+
@test @inferred(map(f -> f, enumerate_static(v0))) === SVector{0,Tuple{Int,Float64}}()
252+
@test @inferred(map(f -> f, enumerate_static(m0))) === SMatrix{0,0,Tuple{Int,Float64}}()
253+
end
243254
end

0 commit comments

Comments
 (0)