Skip to content

Commit 96c92a6

Browse files
c42fandyferris
authored andcommitted
Make map() take mixtures of StaticArray and AbstractArray
This allows StaticArray/AbstractArray linear algebra + and - to return a static array.
1 parent 9a13806 commit 96c92a6

File tree

4 files changed

+56
-28
lines changed

4 files changed

+56
-28
lines changed

src/mapreduce.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
1-
# Returns the common Size of the inputs (or else throws a DimensionMismatch)
2-
@inline same_size(a1::StaticArray, as::StaticArray...) = _same_size(Size(a1), as...)
3-
@inline _same_size(s::Size) = s
4-
@inline function _same_size(s::Size, a1::StaticArray, as::StaticArray...)
5-
if s === Size(a1)
6-
return _same_size(s, as...)
7-
else
8-
throw(DimensionMismatch("Dimensions must match. Got inputs with $s and $(Size(a1))."))
9-
end
10-
end
11-
121
@inline _first(a1, as...) = a1
132

143
################
154
## map / map! ##
165
################
176

18-
@inline function map(f, a::StaticArray, b::StaticArray...)
19-
_map(f, same_size(a, b...), a, b...)
20-
end
21-
22-
@generated function _map(f, ::Size{S}, a::StaticArray...) where {S}
7+
@inline map(f, as::StaticArray...) =
8+
_map(f, same_size(as...), as...)
9+
# Mixed StaticArray + AbstractArray; various versions to avoid ambiguities.
10+
# With the versions below, if a StaticArray isn't present in the first two
11+
# arguments, we'll end up in Base.map() instead.
12+
@inline map(f, a1::StaticArray, as::AbstractArray...) =
13+
_map(f, same_size(a1, as...), a1, as...)
14+
@inline map(f, a1::AbstractArray, a2::StaticArray, as::AbstractArray...) =
15+
_map(f, same_size(a1, a2, as...), a1, a2, as...)
16+
@inline map(f, a1::StaticArray, a2::AbstractArray, as::AbstractArray...) =
17+
_map(f, same_size(a1, a2, as...), a1, a2, as...)
18+
@inline map(f, a1::StaticArray, a2::StaticArray, as::AbstractArray...) =
19+
_map(f, same_size(a1, a2, as...), a1, a2, as...)
20+
21+
@generated function _map(f, ::Size{S}, a::AbstractArray...) where {S}
2322
exprs = Vector{Expr}(prod(S))
2423
for i 1:prod(S)
2524
tmp = [:(a[$j][$i]) for j 1:length(a)]
@@ -29,7 +28,7 @@ end
2928
newT = :(Core.Inference.return_type(f, Tuple{$(eltypes...)}))
3029
return quote
3130
@_inline_meta
32-
@inbounds return similar_type(typeof(_first(a...)), $newT)(tuple($(exprs...)))
31+
@inbounds return similar_type(typeof(_first(a...)), $newT, Size(S))(tuple($(exprs...)))
3332
end
3433
end
3534

src/traits.jl

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@ Length(::Type{SA}) where {SA <: StaticArray} = Length(Size(SA))
8282
@pure length(::Size{S}) where {S} = length(S)
8383
@pure length_val{S}(::Size{S}) = Val{length(S)}
8484

85-
@pure Base.:(==){S}(::Size{S}, s::Tuple{Vararg{Int}}) = S == s
86-
@pure Base.:(==){S}(s::Tuple{Vararg{Int}}, ::Size{S}) = s == S
85+
# Note - using === here, as Base doesn't inline == for tuples as of julia-0.6
86+
@pure Base.:(==){S}(::Size{S}, s::Tuple{Vararg{Int}}) = S === s
87+
@pure Base.:(==){S}(s::Tuple{Vararg{Int}}, ::Size{S}) = s === S
8788

88-
@pure Base.:(!=){S}(::Size{S}, s::Tuple{Vararg{Int}}) = S != s
89-
@pure Base.:(!=){S}(s::Tuple{Vararg{Int}}, ::Size{S}) = s != S
89+
@pure Base.:(!=){S}(::Size{S}, s::Tuple{Vararg{Int}}) = S !== s
90+
@pure Base.:(!=){S}(s::Tuple{Vararg{Int}}, ::Size{S}) = s !== S
9091

9192
@pure Base.prod{S}(::Size{S}) = prod(S)
9293

@@ -104,3 +105,32 @@ Length(::Type{SA}) where {SA <: StaticArray} = Length(Size(SA))
104105

105106
# The generated functions work with length, etc...
106107
@propagate_inbounds unroll_tuple(f, ::Length{L}) where {L} = unroll_tuple(f, Val{L})
108+
109+
110+
"""
111+
Static or runtime size of an array
112+
"""
113+
const SRSize = Union{Size,Tuple{Vararg{Int}}}
114+
115+
"""
116+
Return either the statically known Size() or runtime size()
117+
"""
118+
@inline _size(a) = size(a)
119+
@inline _size(a::StaticArray) = Size(a)
120+
121+
# Return first static Size from a set of arrays
122+
@inline _first_static_size(a1::StaticArray, as...) = Size(a1)
123+
@inline _first_static_size(a1, as...) = _first_static_size(as...)
124+
@inline _first_static_size() = throw(ArgumentError("No StaticArray found in argument list"))
125+
126+
# Returns the common Size of the inputs (or else throws a DimensionMismatch)
127+
@inline same_size(as...) = _same_size(_first_static_size(as...), as...)
128+
@inline function _same_size(s::SRSize, a1, as...)
129+
if s == _size(a1)
130+
return _same_size(s, as...)
131+
else
132+
throw(DimensionMismatch("Dimensions must match. Got inputs with $s and $(_size(a1))."))
133+
end
134+
end
135+
@inline _same_size(s::SRSize) = s
136+

test/linalg.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
v3 = [2,4,6,8]
1818
v4 = [4,3,2,1]
1919

20-
# We broke "inferrable" sizes of AbstractVectors for vector+vector, matrix*vector, etc...
21-
@test_broken @inferred(v1 + v4) === @SVector [6, 7, 8, 9]
22-
@test_broken @inferred(v3 + v2) === @SVector [6, 7, 8, 9]
23-
@test_broken @inferred(v1 - v4) === @SVector [-2, 1, 4, 7]
24-
@test_broken @inferred(v3 - v2) === @SVector [-2, 1, 4, 7]
20+
@test @inferred(v1 + v4) === @SVector [6, 7, 8, 9]
21+
@test @inferred(v3 + v2) === @SVector [6, 7, 8, 9]
22+
@test @inferred(v1 - v4) === @SVector [-2, 1, 4, 7]
23+
@test @inferred(v3 - v2) === @SVector [-2, 1, 4, 7]
2524
end
2625

2726
@testset "Interaction with `UniformScaling`" begin

test/mapreduce.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
@test @inferred(map(-, v1)) === @SVector [-2, -4, -6, -8]
1111
@test @inferred(map(+, v1, v2)) === @SVector [6, 7, 8, 9]
12-
@test_broken @inferred(map(+, normal_v1, v2)) === @SVector [6, 7, 8, 9]
13-
@test_broken @inferred(map(+, v1, normal_v2)) === @SVector [6, 7, 8, 9]
12+
@test @inferred(map(+, normal_v1, v2)) === @SVector [6, 7, 8, 9]
13+
@test @inferred(map(+, v1, normal_v2)) === @SVector [6, 7, 8, 9]
1414

1515
map!(+, mv, v1, v2)
1616
@test mv == @MVector [6, 7, 8, 9]

0 commit comments

Comments
 (0)