Skip to content

Commit e8bdcca

Browse files
committed
[WIP] faster reductions
1 parent 935dc85 commit e8bdcca

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/mapreduce.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@ end
161161
## reduce ##
162162
############
163163

164-
@inline reduce(op, a::StaticArray; kw...) = mapreduce(identity, op, a; kw...)
164+
@inline reduce(op, a::StaticArray; dims=:, kw...) = _reduce(op, a, dims, kw.data)
165+
166+
@inline _reduce(op, a::StaticArray, dims=:, kw::NamedTuple=NamedTuple()) = _mapreduce(identity, op, dims, kw, Size(a), a)
165167

166168
#######################
167169
## related functions ##
@@ -186,9 +188,9 @@ end
186188
# TODO: change to use Base.reduce_empty/Base.reduce_first
187189
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)
188190

189-
@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(+, a; dims=dims)
190-
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims)
191-
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, +, a; dims=dims) # avoid ambiguity
191+
@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(+, a, dims)
192+
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a)
193+
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) # avoid ambiguity
192194

193195
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(*, a; dims=dims)
194196
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)

0 commit comments

Comments
 (0)