Skip to content

Commit 48cd0e4

Browse files
Improve inferability of resursive map on nested StaticArrays (#594)
Remove the `::Size` argument to `_map`, instead determining the size internally.
1 parent bb82cd5 commit 48cd0e4

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

src/mapreduce.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,27 @@
88
# `map(f, as::Union{<:StaticArray,AbstractArray}...)` which included at least one `StaticArray`
99
# this is not the case on 0.7 and we instead hope to find a StaticArray in the first two arguments.
1010
@inline function map(f, a1::StaticArray, as::AbstractArray...)
11-
_map(f, same_size(a1, as...), a1, as...)
11+
_map(f, a1, as...)
1212
end
1313
@inline function map(f, a1::AbstractArray, a2::StaticArray, as::AbstractArray...)
14-
_map(f, same_size(a1, a2, as...), a1, a2, as...)
14+
_map(f, a1, a2, as...)
1515
end
1616
@inline function map(f, a1::StaticArray, a2::StaticArray, as::AbstractArray...)
17-
_map(f, same_size(a1, a2, as...), a1, a2, as...)
17+
_map(f, a1, a2, as...)
1818
end
1919

20-
@generated function _map(f, ::Size{S}, a::AbstractArray...) where {S}
20+
@generated function _map(f, a::AbstractArray...)
21+
i = findfirst(ai -> ai <: StaticArray, a)
22+
if i === nothing
23+
return :(throw(ArgumentError("No StaticArray found in argument list")))
24+
end
25+
# Passing the Size as an argument to _map leads to inference issues when
26+
# recursively mapping over nested StaticArrays (see issue #593). Calling
27+
# Size in the generator here is valid because a[i] is known to be a
28+
# StaticArray for which the default Size method is correct. If wrapped
29+
# StaticArrays (with a custom Size method) are to be supported, this will
30+
# no longer be valid.
31+
S = Size(a[i])
2132
exprs = Vector{Expr}(undef, prod(S))
2233
for i 1:prod(S)
2334
tmp = [:(a[$j][$i]) for j 1:length(a)]
@@ -26,8 +37,9 @@ end
2637

2738
return quote
2839
@_inline_meta
40+
S = same_size(a...)
2941
@inbounds elements = tuple($(exprs...))
30-
@inbounds return similar_type(typeof(_first(a...)), eltype(elements), Size(S))(elements)
42+
@inbounds return similar_type(typeof(_first(a...)), eltype(elements), S)(elements)
3143
end
3244
end
3345

test/mapreduce.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,11 @@ end
149149

150150
@test ((@SVector Int64[]) + (@SVector Int64[])) == (@SVector Int64[])
151151
end
152+
@testset "Nested SVectors" begin
153+
# issue #593
154+
v = SVector(SVector(3, 2), SVector(5, 7))
155+
@test @inferred(v + v) == SVector(SVector(6, 4), SVector(10, 14))
156+
v = SVector(SVector(3, 2, 1), SVector(5, 7, 9))
157+
@test @inferred(v + v) == SVector(SVector(6, 4, 2), SVector(10, 14, 18))
158+
end
152159
end

0 commit comments

Comments
 (0)