Skip to content

Commit 00c4f51

Browse files
committed
move to 8x bins for commutative ops
1 parent 264b8db commit 00c4f51

File tree

1 file changed

+45
-25
lines changed

1 file changed

+45
-25
lines changed

base/multidimensional.jl

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,16 +1118,16 @@ end
11181118
function _mapreduce_kernel_commutative(f, op, A, init, inds, leading=(), trailing=())
11191119
i1, iN = firstindex(inds), lastindex(inds)
11201120
n = length(inds)
1121-
@nexprs 4 N->a_N = @inbounds A[leading..., inds[i1+(N-1)], trailing...]
1122-
@nexprs 4 N->v_N = _mapreduce_start(f, op, A, init, a_N)
1123-
for batch in 1:(n>>2)-1
1124-
i = i1 + batch*4
1125-
@nexprs 4 N->a_N = @inbounds A[leading..., inds[i+(N-1)], trailing...]
1126-
@nexprs 4 N->fa_N = f(a_N)
1127-
@nexprs 4 N->v_N = op(v_N, fa_N)
1128-
end
1129-
v = op(op(v_1, v_2), op(v_3, v_4))
1130-
i = i1 + (n>>2)*4 - 1
1121+
@nexprs 8 N->a_N = @inbounds A[leading..., inds[i1+(N-1)], trailing...]
1122+
@nexprs 8 N->v_N = _mapreduce_start(f, op, A, init, a_N)
1123+
for batch in 1:(n>>3)-1
1124+
i = i1 + batch*8
1125+
@nexprs 8 N->a_N = @inbounds A[leading..., inds[i+(N-1)], trailing...]
1126+
@nexprs 8 N->fa_N = f(a_N)
1127+
@nexprs 8 N->v_N = op(v_N, fa_N)
1128+
end
1129+
v = op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), op(v_7, v_8)))
1130+
i = i1 + (n>>3)*8 - 1
11311131
i == iN && return v
11321132
for i in i+1:iN
11331133
ai = @inbounds A[leading..., inds[i], trailing...]
@@ -1180,23 +1180,23 @@ function mapreduce_kernel_commutative(f, op, itr, init, ::Union{HasLength, HasSh
11801180
end
11811181
return v_1, s
11821182
end
1183-
@nexprs 3 n->begin
1183+
@nexprs 7 n->begin
11841184
it = iterate(itr, s)
11851185
it === nothing && _throw_iterator_assertion_error()
11861186
a, s = it
11871187
v_{n+1} = _mapreduce_start(f, op, itr, init, a)
11881188
end
1189-
i = 4
1190-
for outer i in 8:4:n
1191-
@nexprs 4 n->begin
1189+
i = 8
1190+
for outer i in 16:8:n
1191+
@nexprs 8 n->begin
11921192
it = iterate(itr, s)
11931193
it === nothing && _throw_iterator_assertion_error()
11941194
a_n, s = it
11951195
end
1196-
@nexprs 4 n-> fa_n = f(a_n)
1197-
@nexprs 4 n-> v_n = op(v_n, fa_n)
1196+
@nexprs 8 n-> fa_n = f(a_n)
1197+
@nexprs 8 n-> v_n = op(v_n, fa_n)
11981198
end
1199-
v = op(op(v_1, v_2), op(v_3, v_4))
1199+
v = op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), op(v_7, v_8)))
12001200
for _ in i+1:n
12011201
it = iterate(itr, s)
12021202
it === nothing && _throw_iterator_assertion_error()
@@ -1222,23 +1222,43 @@ function mapreduce_kernel_commutative(f, op, itr, init, ::IteratorSize, n, state
12221222
it === nothing && return Some(op(op(v_1, v_2), v_3))
12231223
a, s = it
12241224
v_4 = _mapreduce_start(f, op, itr, init, a)
1225-
for _ in 2:n>>2
1226-
@nexprs 4 N->begin
1225+
it = iterate(itr, s)
1226+
it === nothing && return Some(op(op(v_1, v_2), op(v_3, v_4)))
1227+
a, s = it
1228+
v_5 = _mapreduce_start(f, op, itr, init, a)
1229+
it = iterate(itr, s)
1230+
it === nothing && return Some(op(op(op(v_1, v_2), op(v_3, v_4)), v_5))
1231+
a, s = it
1232+
v_6 = _mapreduce_start(f, op, itr, init, a)
1233+
it = iterate(itr, s)
1234+
it === nothing && return Some(op(op(op(v_1, v_2), op(v_3, v_4)), op(v_5, v_6)))
1235+
a, s = it
1236+
v_7 = _mapreduce_start(f, op, itr, init, a)
1237+
it = iterate(itr, s)
1238+
it === nothing && return Some(op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), v_7)))
1239+
a, s = it
1240+
v_8 = _mapreduce_start(f, op, itr, init, a)
1241+
for _ in 3:n>>3
1242+
@nexprs 8 N->begin
12271243
it = iterate(itr, s)
12281244
if it === nothing
1245+
N > 7 && (v_7 = op(v_7, f(a_7)))
1246+
N > 6 && (v_6 = op(v_6, f(a_6)))
1247+
N > 5 && (v_5 = op(v_5, f(a_5)))
1248+
N > 4 && (v_4 = op(v_4, f(a_4)))
12291249
N > 3 && (v_3 = op(v_3, f(a_3)))
12301250
N > 2 && (v_2 = op(v_2, f(a_2)))
12311251
N > 1 && (v_1 = op(v_1, f(a_1)))
1232-
return Some(op(op(v_1, v_2), op(v_3, v_4)))
1252+
return Some(op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), op(v_7, v_8))))
12331253
end
12341254
a_N, s = it
12351255
end
1236-
@nexprs 4 N->fa_N = f(a_N)
1237-
@nexprs 4 N->v_N = op(v_N, fa_N)
1256+
@nexprs 8 N->fa_N = f(a_N)
1257+
@nexprs 8 N->v_N = op(v_N, fa_N)
12381258
end
1239-
v = op(op(v_1, v_2), op(v_3, v_4))
1240-
i = (n>>2)*4
1241-
@nexprs 4 N->begin
1259+
v = op(op(op(v_1, v_2), op(v_3, v_4)), op(op(v_5, v_6), op(v_7, v_8)))
1260+
i = (n>>3)*8
1261+
@nexprs 8 N->begin
12421262
it = iterate(itr, s)
12431263
if it === nothing
12441264
return Some(v)

0 commit comments

Comments
 (0)