Skip to content

Commit c7b81d6

Browse files
committed
faster reductions (completed)
1 parent e8bdcca commit c7b81d6

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

src/mapreduce.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ end
163163

164164
@inline reduce(op, a::StaticArray; dims=:, kw...) = _reduce(op, a, dims, kw.data)
165165

166-
@inline _reduce(op, a::StaticArray, dims=:, kw::NamedTuple=NamedTuple()) = _mapreduce(identity, op, dims, kw, Size(a), a)
166+
@inline _reduce(op, a::StaticArray, dims, kw::NamedTuple=NamedTuple()) = _mapreduce(identity, op, dims, kw, Size(a), a)
167167

168168
#######################
169169
## related functions ##
@@ -192,34 +192,34 @@ end
192192
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a)
193193
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) # avoid ambiguity
194194

195-
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = reduce(*, a; dims=dims)
196-
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
197-
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = mapreduce(f, *, a; dims=dims)
195+
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(*, a, dims)
196+
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
197+
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
198198

199-
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(+, a; dims=dims)
200-
@inline count(f, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, +, a; dims=dims)
199+
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(+, a, dims)
200+
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, NamedTuple(), Size(a), a)
201201

202-
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(&, a; dims=dims, init=true) # non-branching versions
203-
@inline all(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, &, a; dims=dims, init=true)
202+
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, (init=true,)) # non-branching versions
203+
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, (init=true,), Size(a), a)
204204

205-
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = reduce(|, a; dims=dims, init=false) # (benchmarking needed)
206-
@inline any(f::Function, a::StaticArray; dims=:) = mapreduce(x->f(x)::Bool, |, a; dims=dims, init=false) # (benchmarking needed)
205+
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, (init=false,)) # (benchmarking needed)
206+
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, (init=false,), Size(a), a) # (benchmarking needed)
207207

208-
@inline Base.in(x, a::StaticArray) = mapreduce(==(x), |, a, init=false)
208+
@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, (init=false,), Size(a), a)
209209

210210
_mean_denom(a, dims::Colon) = length(a)
211211
_mean_denom(a, dims::Int) = size(a, dims)
212212
_mean_denom(a, ::Val{D}) where {D} = size(a, D)
213213
_mean_denom(a, ::Type{Val{D}}) where {D} = size(a, D)
214214

215-
@inline mean(a::StaticArray; dims=:) = sum(a; dims=dims) / _mean_denom(a,dims)
216-
@inline mean(f::Function, a::StaticArray;dims=:) = sum(f, a; dims=dims) / _mean_denom(a,dims)
215+
@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
216+
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) / _mean_denom(a, dims)
217217

218-
@inline minimum(a::StaticArray; dims=:) = reduce(min, a; dims=dims) # base has mapreduce(idenity, scalarmin, a)
219-
@inline minimum(f::Function, a::StaticArray; dims=:) = mapreduce(f, min, a; dims=dims)
218+
@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(idenity, scalarmin, a)
219+
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, NamedTuple(), Size(a), a)
220220

221-
@inline maximum(a::StaticArray; dims=:) = reduce(max, a; dims=dims) # base has mapreduce(idenity, scalarmax, a)
222-
@inline maximum(f::Function, a::StaticArray; dims=:) = mapreduce(f, max, a; dims=dims)
221+
@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(idenity, scalarmax, a)
222+
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, NamedTuple(), Size(a), a)
223223

224224
# Diff is slightly different
225225
@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims)

0 commit comments

Comments
 (0)