Skip to content

Commit 2cd5133

Browse files
authored
Merge pull request #210 from fredrikekre/fe/inline-sum
inline sum(f::Callable, a::StaticArray)
2 parents f7e6eb6 + c5401c1 commit 2cd5133

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/mapreduce.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ end
169169

170170
# These are all similar in Base but not @inline'd
171171
@inline sum(a::StaticArray{<:Any, T}) where {T} = reduce(+, zero(T), a)
172+
@inline sum(f::Base.Callable, a::StaticArray) = mapreduce(f, +, a)
172173
@inline prod(a::StaticArray{<:Any, T}) where {T} = reduce(*, one(T), a)
173174
@inline count(a::StaticArray{<:Any, Bool}) = reduce(+, 0, a)
174175
@inline all(a::StaticArray{<:Any, Bool}) = reduce(&, true, a) # non-branching versions

test/mapreduce.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
@test reduce(+, v1) === 20
2929
@test reduce(+, 0, v1) === 20
3030
@test sum(v1) === 20
31+
@test sum(abs2, v1) === 120
3132
@test prod(v1) === 384
3233
@test mean(v1) === 5.
3334
@test maximum(v1) === 8
@@ -52,7 +53,7 @@
5253
a = @SArray rand(4,3) # as of Julia v0.5, diff() for regular Array is defined only for vectors and matrices
5354
@test diff(a) == diff(a, Val{1}) == diff(a, 1)
5455
@test diff(a, Val{2}) == diff(a, 2)
55-
56+
5657
@test reducedim(max, a, Val{1}, -1.) == reducedim(max, a, 1, -1.)
5758
@test reducedim(max, a, Val{2}, -1.) == reducedim(max, a, 2, -1.)
5859
end
@@ -83,7 +84,7 @@
8384
@test @inferred(broadcast(+, v1, c)) === @SVector [4, 6, 8, 10]
8485
@test @inferred(broadcast(+, v1, v2)) === map(+, v1, v2)
8586
@test @inferred(broadcast(+, v1, M)) === @SMatrix [3 4; 7 8; 11 12; 15 16]
86-
87+
8788
@test_throws DimensionMismatch broadcast!(-, MVector{5, Int}(), v1)
8889

8990
broadcast!(-, mv, v1)
@@ -96,7 +97,7 @@
9697
@test mm == @MMatrix [3 4; 7 8; 11 12; 15 16]
9798
# issue #103
9899
@test map(+, M, M) == [2 4; 6 8; 10 12; 14 16]
99-
100+
100101
@test ((@SVector Int64[]) + (@SVector Int64[])) == (@SVector Int64[])
101102
end
102103
end

0 commit comments

Comments
 (0)