Skip to content

Commit 22b9812

Browse files
author
Chris Foster
committed
Merge branch mixed-sarray-array-mapping
2 parents 9a13806 + 8d3ec32 commit 22b9812

File tree

4 files changed

+43
-26
lines changed

4 files changed

+43
-26
lines changed

src/mapreduce.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,16 @@
1-
# Returns the common Size of the inputs (or else throws a DimensionMismatch)
2-
@inline same_size(a1::StaticArray, as::StaticArray...) = _same_size(Size(a1), as...)
3-
@inline _same_size(s::Size) = s
4-
@inline function _same_size(s::Size, a1::StaticArray, as::StaticArray...)
5-
if s === Size(a1)
6-
return _same_size(s, as...)
7-
else
8-
throw(DimensionMismatch("Dimensions must match. Got inputs with $s and $(Size(a1))."))
9-
end
10-
end
11-
121
@inline _first(a1, as...) = a1
132

143
################
154
## map / map! ##
165
################
176

18-
@inline function map(f, a::StaticArray, b::StaticArray...)
19-
_map(f, same_size(a, b...), a, b...)
7+
# The following type signature for map() matches any list of AbstractArrays,
8+
# provided at least one is a static array.
9+
@inline function map(f, as::Union{SA,AbstractArray}...) where {SA<:StaticArray}
10+
_map(f, same_size(as...), as...)
2011
end
2112

22-
@generated function _map(f, ::Size{S}, a::StaticArray...) where {S}
13+
@generated function _map(f, ::Size{S}, a::AbstractArray...) where {S}
2314
exprs = Vector{Expr}(prod(S))
2415
for i 1:prod(S)
2516
tmp = [:(a[$j][$i]) for j 1:length(a)]
@@ -29,7 +20,7 @@ end
2920
newT = :(Core.Inference.return_type(f, Tuple{$(eltypes...)}))
3021
return quote
3122
@_inline_meta
32-
@inbounds return similar_type(typeof(_first(a...)), $newT)(tuple($(exprs...)))
23+
@inbounds return similar_type(typeof(_first(a...)), $newT, Size(S))(tuple($(exprs...)))
3324
end
3425
end
3526

src/traits.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@ Length(::Type{SA}) where {SA <: StaticArray} = Length(Size(SA))
8282
@pure length(::Size{S}) where {S} = length(S)
8383
@pure length_val{S}(::Size{S}) = Val{length(S)}
8484

85-
@pure Base.:(==){S}(::Size{S}, s::Tuple{Vararg{Int}}) = S == s
86-
@pure Base.:(==){S}(s::Tuple{Vararg{Int}}, ::Size{S}) = s == S
85+
# Note - using === here, as Base doesn't inline == for tuples as of julia-0.6
86+
@pure Base.:(==){S}(::Size{S}, s::Tuple{Vararg{Int}}) = S === s
87+
@pure Base.:(==){S}(s::Tuple{Vararg{Int}}, ::Size{S}) = s === S
8788

88-
@pure Base.:(!=){S}(::Size{S}, s::Tuple{Vararg{Int}}) = S != s
89-
@pure Base.:(!=){S}(s::Tuple{Vararg{Int}}, ::Size{S}) = s != S
89+
@pure Base.:(!=){S}(::Size{S}, s::Tuple{Vararg{Int}}) = S !== s
90+
@pure Base.:(!=){S}(s::Tuple{Vararg{Int}}, ::Size{S}) = s !== S
9091

9192
@pure Base.prod{S}(::Size{S}) = prod(S)
9293

@@ -104,3 +105,29 @@ Length(::Type{SA}) where {SA <: StaticArray} = Length(Size(SA))
104105

105106
# The generated functions work with length, etc...
106107
@propagate_inbounds unroll_tuple(f, ::Length{L}) where {L} = unroll_tuple(f, Val{L})
108+
109+
110+
"""
111+
Return either the statically known Size() or runtime size()
112+
"""
113+
@inline _size(a) = size(a)
114+
@inline _size(a::StaticArray) = Size(a)
115+
116+
# Return static array from a set of arrays
117+
@inline _first_static(a1::StaticArray, as...) = a1
118+
@inline _first_static(a1, as...) = _first_static(as...)
119+
@inline _first_static() = throw(ArgumentError("No StaticArray found in argument list"))
120+
121+
"""
122+
Returns the common Size of the inputs (or else throws a DimensionMismatch)
123+
"""
124+
@inline function same_size(as...)
125+
s = Size(_first_static(as...))
126+
_sizes_match(s, as...) || _throw_size_mismatch(as...)
127+
s
128+
end
129+
@inline _sizes_match(s::Size, a1, as...) = ((s == _size(a1)) ? _sizes_match(s, as...) : false)
130+
@inline _sizes_match(s::Size) = true
131+
@noinline function _throw_size_mismatch(as...)
132+
throw(DimensionMismatch("Sizes $(map(_size, as)) of input arrays do not match"))
133+
end

test/linalg.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
v3 = [2,4,6,8]
1818
v4 = [4,3,2,1]
1919

20-
# We broke "inferrable" sizes of AbstractVectors for vector+vector, matrix*vector, etc...
21-
@test_broken @inferred(v1 + v4) === @SVector [6, 7, 8, 9]
22-
@test_broken @inferred(v3 + v2) === @SVector [6, 7, 8, 9]
23-
@test_broken @inferred(v1 - v4) === @SVector [-2, 1, 4, 7]
24-
@test_broken @inferred(v3 - v2) === @SVector [-2, 1, 4, 7]
20+
@test @inferred(v1 + v4) === @SVector [6, 7, 8, 9]
21+
@test @inferred(v3 + v2) === @SVector [6, 7, 8, 9]
22+
@test @inferred(v1 - v4) === @SVector [-2, 1, 4, 7]
23+
@test @inferred(v3 - v2) === @SVector [-2, 1, 4, 7]
2524
end
2625

2726
@testset "Interaction with `UniformScaling`" begin

test/mapreduce.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
@test @inferred(map(-, v1)) === @SVector [-2, -4, -6, -8]
1111
@test @inferred(map(+, v1, v2)) === @SVector [6, 7, 8, 9]
12-
@test_broken @inferred(map(+, normal_v1, v2)) === @SVector [6, 7, 8, 9]
13-
@test_broken @inferred(map(+, v1, normal_v2)) === @SVector [6, 7, 8, 9]
12+
@test @inferred(map(+, normal_v1, v2)) === @SVector [6, 7, 8, 9]
13+
@test @inferred(map(+, v1, normal_v2)) === @SVector [6, 7, 8, 9]
1414

1515
map!(+, mv, v1, v2)
1616
@test mv == @MVector [6, 7, 8, 9]

0 commit comments

Comments
 (0)