Skip to content

Commit 16fa766

Browse files
Chris Fosterandyferris
authored andcommitted
Refactor to generalize map() and improve codegen for same_size()
1 parent 64ec1a2 commit 16fa766

File tree

2 files changed

+20
-30
lines changed

2 files changed

+20
-30
lines changed

src/mapreduce.jl

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,10 @@
1-
@inline _first(a1, as...) = a1
2-
31
################
42
## map / map! ##
53
################
64

7-
@inline map(f, as::StaticArray...) =
5+
@inline function map(f, as::Union{SA,AbstractArray}...) where {SA<:StaticArray}
86
_map(f, same_size(as...), as...)
9-
# Mixed StaticArray + AbstractArray; various versions to avoid ambiguities.
10-
# With the versions below, if a StaticArray isn't present in the first two
11-
# arguments, we'll end up in Base.map() instead.
12-
@inline map(f, a1::StaticArray, as::AbstractArray...) =
13-
_map(f, same_size(a1, as...), a1, as...)
14-
@inline map(f, a1::AbstractArray, a2::StaticArray, as::AbstractArray...) =
15-
_map(f, same_size(a1, a2, as...), a1, a2, as...)
16-
@inline map(f, a1::StaticArray, a2::AbstractArray, as::AbstractArray...) =
17-
_map(f, same_size(a1, a2, as...), a1, a2, as...)
18-
@inline map(f, a1::StaticArray, a2::StaticArray, as::AbstractArray...) =
19-
_map(f, same_size(a1, a2, as...), a1, a2, as...)
7+
end
208

219
@generated function _map(f, ::Size{S}, a::AbstractArray...) where {S}
2210
exprs = Vector{Expr}(prod(S))
@@ -28,7 +16,7 @@
2816
newT = :(Core.Inference.return_type(f, Tuple{$(eltypes...)}))
2917
return quote
3018
@_inline_meta
31-
@inbounds return similar_type(typeof(_first(a...)), $newT, Size(S))(tuple($(exprs...)))
19+
@inbounds return similar_type(typeof(_first_static(a...)), $newT, Size(S))(tuple($(exprs...)))
3220
end
3321
end
3422

src/traits.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,19 +113,21 @@ Return either the statically known Size() or runtime size()
113113
@inline _size(a) = size(a)
114114
@inline _size(a::StaticArray) = Size(a)
115115

116-
# Return first static Size from a set of arrays
117-
@inline _first_static_size(a1::StaticArray, as...) = Size(a1)
118-
@inline _first_static_size(a1, as...) = _first_static_size(as...)
119-
@inline _first_static_size() = throw(ArgumentError("No StaticArray found in argument list"))
120-
121-
# Returns the common Size of the inputs (or else throws a DimensionMismatch)
122-
@inline same_size(as...) = _same_size(_first_static_size(as...), as...)
123-
@inline function _same_size(s::Size, a1, as...)
124-
if s == _size(a1)
125-
return _same_size(s, as...)
126-
else
127-
throw(DimensionMismatch("Dimensions must match. Got inputs with $s and $(_size(a1))."))
128-
end
129-
end
130-
@inline _same_size(s::Size) = s
116+
# Return static array from a set of arrays
117+
@inline _first_static(a1::StaticArray, as...) = a1
118+
@inline _first_static(a1, as...) = _first_static(as...)
119+
@inline _first_static() = throw(ArgumentError("No StaticArray found in argument list"))
131120

121+
"""
122+
Returns the common Size of the inputs (or else throws a DimensionMismatch)
123+
"""
124+
@inline function same_size(as...)
125+
s = Size(_first_static(as...))
126+
_sizes_match(s, as...) || _throw_size_mismatch(as...)
127+
s
128+
end
129+
@inline _sizes_match(s::Size, a1, as...) = ((s == _size(a1)) ? _sizes_match(s, as...) : false)
130+
@inline _sizes_match(s::Size) = true
131+
@noinline function _throw_size_mismatch(as...)
132+
throw(DimensionMismatch("Sizes $(map(_size, as)) of input arrays do not match"))
133+
end

0 commit comments

Comments
 (0)