Skip to content

Commit 3d5cfec

Browse files
committed
move to 8x bins for commutative ops
1 parent 264b8db commit 3d5cfec

File tree

1 file changed

+48
-28
lines changed

1 file changed

+48
-28
lines changed

base/multidimensional.jl

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,21 +1113,21 @@ function mapreduce_kernel_commutative(f, op, A, init, inds::AbstractArray)
11131113
return _mapreduce_kernel_commutative(f, op, A, init, inds)
11141114
end
11151115

1116-
# This special internal method must have at least 4 indices and allows passing
1116+
# This special internal method must have at least 8 indices and allows passing
11171117
# optional scalar leading and trailing dimensions
11181118
function _mapreduce_kernel_commutative(f, op, A, init, inds, leading=(), trailing=())
11191119
i1, iN = firstindex(inds), lastindex(inds)
11201120
n = length(inds)
1121-
@nexprs 4 N->a_N = @inbounds A[leading..., inds[i1+(N-1)], trailing...]
1122-
@nexprs 4 N->v_N = _mapreduce_start(f, op, A, init, a_N)
1123-
for batch in 1:(n>>2)-1
1124-
i = i1 + batch*4
1125-
@nexprs 4 N->a_N = @inbounds A[leading..., inds[i+(N-1)], trailing...]
1126-
@nexprs 4 N->fa_N = f(a_N)
1127-
@nexprs 4 N->v_N = op(v_N, fa_N)
1128-
end
1129-
v = op(op(v_1, v_2), op(v_3, v_4))
1130-
i = i1 + (n>>2)*4 - 1
1121+
@nexprs 8 N->a_N = @inbounds A[leading..., inds[i1+(N-1)], trailing...]
1122+
@nexprs 8 N->v_N = _mapreduce_start(f, op, A, init, a_N)
1123+
for batch in 1:(n>>3)-1
1124+
i = i1 + batch*8
1125+
@nexprs 8 N->a_N = @inbounds A[leading..., inds[i+(N-1)], trailing...]
1126+
@nexprs 8 N->fa_N = f(a_N)
1127+
@nexprs 8 N->v_N = op(v_N, fa_N)
1128+
end
1129+
v = op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), op(v_7, v_8)))
1130+
i = i1 + (n>>3)*8 - 1
11311131
i == iN && return v
11321132
for i in i+1:iN
11331133
ai = @inbounds A[leading..., inds[i], trailing...]
@@ -1141,7 +1141,7 @@ function mapreduce_kernel_commutative(f::F, op::G, A, init, inds::CartesianIndic
11411141
N == 1 && return mapreduce_kernel_commutative(f, op, A, init, inds.indices[1])
11421142
is = inds.indices[1]
11431143
js = inds.indices[2]
1144-
if length(is) == 1 && length(js) >= 4
1144+
if length(is) == 1 && length(js) >= 8
11451145
# It's quite useful to optimize this case for dimensional reductions
11461146
i = only(is)
11471147
outer = CartesianIndices(tail(tail(inds.indices)))
@@ -1151,7 +1151,7 @@ function mapreduce_kernel_commutative(f::F, op::G, A, init, inds::CartesianIndic
11511151
v = op(v, _mapreduce_kernel_commutative(f, op, A, init, js, (i,), o.I))
11521152
end
11531153
return v
1154-
elseif length(is) < 4 # TODO: tune this number
1154+
elseif length(is) < 8 # TODO: tune this number
11551155
# These small cases could be further optimized
11561156
return mapreduce_kernel_commutative(i->f(A[i]), op, inds, init, HasShape{N}(), length(inds))[1]
11571157
else
@@ -1180,23 +1180,23 @@ function mapreduce_kernel_commutative(f, op, itr, init, ::Union{HasLength, HasSh
11801180
end
11811181
return v_1, s
11821182
end
1183-
@nexprs 3 n->begin
1183+
@nexprs 7 n->begin
11841184
it = iterate(itr, s)
11851185
it === nothing && _throw_iterator_assertion_error()
11861186
a, s = it
11871187
v_{n+1} = _mapreduce_start(f, op, itr, init, a)
11881188
end
1189-
i = 4
1190-
for outer i in 8:4:n
1191-
@nexprs 4 n->begin
1189+
i = 8
1190+
for outer i in 16:8:n
1191+
@nexprs 8 n->begin
11921192
it = iterate(itr, s)
11931193
it === nothing && _throw_iterator_assertion_error()
11941194
a_n, s = it
11951195
end
1196-
@nexprs 4 n-> fa_n = f(a_n)
1197-
@nexprs 4 n-> v_n = op(v_n, fa_n)
1196+
@nexprs 8 n-> fa_n = f(a_n)
1197+
@nexprs 8 n-> v_n = op(v_n, fa_n)
11981198
end
1199-
v = op(op(v_1, v_2), op(v_3, v_4))
1199+
v = op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), op(v_7, v_8)))
12001200
for _ in i+1:n
12011201
it = iterate(itr, s)
12021202
it === nothing && _throw_iterator_assertion_error()
@@ -1222,23 +1222,43 @@ function mapreduce_kernel_commutative(f, op, itr, init, ::IteratorSize, n, state
12221222
it === nothing && return Some(op(op(v_1, v_2), v_3))
12231223
a, s = it
12241224
v_4 = _mapreduce_start(f, op, itr, init, a)
1225-
for _ in 2:n>>2
1226-
@nexprs 4 N->begin
1225+
it = iterate(itr, s)
1226+
it === nothing && return Some(op(op(v_1, v_2), op(v_3, v_4)))
1227+
a, s = it
1228+
v_5 = _mapreduce_start(f, op, itr, init, a)
1229+
it = iterate(itr, s)
1230+
it === nothing && return Some(op(op(op(v_1, v_2), op(v_3, v_4)), v_5))
1231+
a, s = it
1232+
v_6 = _mapreduce_start(f, op, itr, init, a)
1233+
it = iterate(itr, s)
1234+
it === nothing && return Some(op(op(op(v_1, v_2), op(v_3, v_4)), op(v_5, v_6)))
1235+
a, s = it
1236+
v_7 = _mapreduce_start(f, op, itr, init, a)
1237+
it = iterate(itr, s)
1238+
it === nothing && return Some(op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), v_7)))
1239+
a, s = it
1240+
v_8 = _mapreduce_start(f, op, itr, init, a)
1241+
for _ in 3:n>>3
1242+
@nexprs 8 N->begin
12271243
it = iterate(itr, s)
12281244
if it === nothing
1245+
N > 7 && (v_7 = op(v_7, f(a_7)))
1246+
N > 6 && (v_6 = op(v_6, f(a_6)))
1247+
N > 5 && (v_5 = op(v_5, f(a_5)))
1248+
N > 4 && (v_4 = op(v_4, f(a_4)))
12291249
N > 3 && (v_3 = op(v_3, f(a_3)))
12301250
N > 2 && (v_2 = op(v_2, f(a_2)))
12311251
N > 1 && (v_1 = op(v_1, f(a_1)))
1232-
return Some(op(op(v_1, v_2), op(v_3, v_4)))
1252+
return Some(op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), op(v_7, v_8))))
12331253
end
12341254
a_N, s = it
12351255
end
1236-
@nexprs 4 N->fa_N = f(a_N)
1237-
@nexprs 4 N->v_N = op(v_N, fa_N)
1256+
@nexprs 8 N->fa_N = f(a_N)
1257+
@nexprs 8 N->v_N = op(v_N, fa_N)
12381258
end
1239-
v = op(op(v_1, v_2), op(v_3, v_4))
1240-
i = (n>>2)*4
1241-
@nexprs 4 N->begin
1259+
v = op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), op(v_7, v_8)))
1260+
i = (n>>3)*8
1261+
@nexprs 8 N->begin
12421262
it = iterate(itr, s)
12431263
if it === nothing
12441264
return Some(v)

0 commit comments

Comments
 (0)