From ffc962b9eed144c23963d62e6e512d5ac98532e8 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Fri, 28 Feb 2020 23:33:28 -0800 Subject: [PATCH] Add divide-and-conquer _mapreduce --- src/mapreduce.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 30c09c3a..6d612ead 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -112,6 +112,28 @@ end @inline _mapreduce(args::Vararg{Any,N}) where N = _mapfoldl(args...) +@generated function _mapreduce(f, op, dims::Colon, init, ::Size{S}, a::StaticArray...) where {S} + function op_expr(idx) + if length(idx) == 1 + i, = idx + tmp = [:(a[$j][$i]) for j ∈ 1:length(a)] + expr = :(f($(tmp...))) + if init === _InitialValue + return :(Base.reduce_first(op, $expr)) + else + return :(op(init, $expr)) + end + end + left = op_expr(idx[1:end÷2]) + right = op_expr(idx[end÷2+1:end]) + return :(op($left, $right)) + end + return quote + @_inline_meta + @inbounds return $(op_expr(1:prod(S))) + end +end + @generated function _mapfoldl(f, op, dims::Colon, init, ::Size{S}, a::StaticArray...) where {S} tmp = [:(a[$j][1]) for j ∈ 1:length(a)] expr = :(f($(tmp...)))