@@ -1118,16 +1118,16 @@ end
1118
1118
function _mapreduce_kernel_commutative (f, op, A, init, inds, leading= (), trailing= ())
1119
1119
i1, iN = firstindex (inds), lastindex (inds)
1120
1120
n = length (inds)
1121
- @nexprs 4 N-> a_N = @inbounds A[leading... , inds[i1+ (N- 1 )], trailing... ]
1122
- @nexprs 4 N-> v_N = _mapreduce_start (f, op, A, init, a_N)
1123
- for batch in 1 : (n>> 2 )- 1
1124
- i = i1 + batch* 4
1125
- @nexprs 4 N-> a_N = @inbounds A[leading... , inds[i+ (N- 1 )], trailing... ]
1126
- @nexprs 4 N-> fa_N = f (a_N)
1127
- @nexprs 4 N-> v_N = op (v_N, fa_N)
1128
- end
1129
- v = op (op (v_1, v_2), op (v_3, v_4))
1130
- i = i1 + (n>> 2 ) * 4 - 1
1121
+ @nexprs 8 N-> a_N = @inbounds A[leading... , inds[i1+ (N- 1 )], trailing... ]
1122
+ @nexprs 8 N-> v_N = _mapreduce_start (f, op, A, init, a_N)
1123
+ for batch in 1 : (n>> 3 )- 1
1124
+ i = i1 + batch* 8
1125
+ @nexprs 8 N-> a_N = @inbounds A[leading... , inds[i+ (N- 1 )], trailing... ]
1126
+ @nexprs 8 N-> fa_N = f (a_N)
1127
+ @nexprs 8 N-> v_N = op (v_N, fa_N)
1128
+ end
1129
+ v = op (op (op ( v_1, v_2), op (v_3, v_4)), op ( op (v_5, v_6), op (v_7, v_8) ))
1130
+ i = i1 + (n>> 3 ) * 8 - 1
1131
1131
i == iN && return v
1132
1132
for i in i+ 1 : iN
1133
1133
ai = @inbounds A[leading... , inds[i], trailing... ]
@@ -1180,23 +1180,23 @@ function mapreduce_kernel_commutative(f, op, itr, init, ::Union{HasLength, HasSh
1180
1180
end
1181
1181
return v_1, s
1182
1182
end
1183
- @nexprs 3 n-> begin
1183
+ @nexprs 7 n-> begin
1184
1184
it = iterate (itr, s)
1185
1185
it === nothing && _throw_iterator_assertion_error ()
1186
1186
a, s = it
1187
1187
v_{n+ 1 } = _mapreduce_start (f, op, itr, init, a)
1188
1188
end
1189
- i = 4
1190
- for outer i in 8 : 4 : n
1191
- @nexprs 4 n-> begin
1189
+ i = 8
1190
+ for outer i in 16 : 8 : n
1191
+ @nexprs 8 n-> begin
1192
1192
it = iterate (itr, s)
1193
1193
it === nothing && _throw_iterator_assertion_error ()
1194
1194
a_n, s = it
1195
1195
end
1196
- @nexprs 4 n-> fa_n = f (a_n)
1197
- @nexprs 4 n-> v_n = op (v_n, fa_n)
1196
+ @nexprs 8 n-> fa_n = f (a_n)
1197
+ @nexprs 8 n-> v_n = op (v_n, fa_n)
1198
1198
end
1199
- v = op (op (v_1, v_2), op (v_3, v_4))
1199
+ v = op (op (op ( v_1, v_2), op (v_3, v_4)), op ( op (v_5, v_6), op (v_7, v_8) ))
1200
1200
for _ in i+ 1 : n
1201
1201
it = iterate (itr, s)
1202
1202
it === nothing && _throw_iterator_assertion_error ()
@@ -1222,23 +1222,43 @@ function mapreduce_kernel_commutative(f, op, itr, init, ::IteratorSize, n, state
1222
1222
it === nothing && return Some (op (op (v_1, v_2), v_3))
1223
1223
a, s = it
1224
1224
v_4 = _mapreduce_start (f, op, itr, init, a)
1225
- for _ in 2 : n>> 2
1226
- @nexprs 4 N-> begin
1225
+ it = iterate (itr, s)
1226
+ it === nothing && return Some (op (op (v_1, v_2), op (v_3, v_4)))
1227
+ a, s = it
1228
+ v_5 = _mapreduce_start (f, op, itr, init, a)
1229
+ it = iterate (itr, s)
1230
+ it === nothing && return Some (op (op (op (v_1, v_2), op (v_3, v_4)), v_5))
1231
+ a, s = it
1232
+ v_6 = _mapreduce_start (f, op, itr, init, a)
1233
+ it = iterate (itr, s)
1234
+ it === nothing && return Some (op (op (op (v_1, v_2), op (v_3, v_4)), op (v_5, v_6)))
1235
+ a, s = it
1236
+ v_7 = _mapreduce_start (f, op, itr, init, a)
1237
+ it = iterate (itr, s)
1238
+ it === nothing && return Some (op (op (op (v_1, v_2), op (v_3, v_4)), op (op (v_5, v_6), v_7)))
1239
+ a, s = it
1240
+ v_8 = _mapreduce_start (f, op, itr, init, a)
1241
+ for _ in 3 : n>> 3
1242
+ @nexprs 8 N-> begin
1227
1243
it = iterate (itr, s)
1228
1244
if it === nothing
1245
+ N > 7 && (v_7 = op (v_7, f (a_7)))
1246
+ N > 6 && (v_6 = op (v_6, f (a_6)))
1247
+ N > 5 && (v_5 = op (v_5, f (a_5)))
1248
+ N > 4 && (v_4 = op (v_4, f (a_4)))
1229
1249
N > 3 && (v_3 = op (v_3, f (a_3)))
1230
1250
N > 2 && (v_2 = op (v_2, f (a_2)))
1231
1251
N > 1 && (v_1 = op (v_1, f (a_1)))
1232
- return Some (op (op (v_1, v_2), op (v_3, v_4)))
1252
+ return Some (op (op (op ( v_1, v_2), op (v_3, v_4)), op ( op (v_5, v_6), op (v_7, v_8) )))
1233
1253
end
1234
1254
a_N, s = it
1235
1255
end
1236
- @nexprs 4 N-> fa_N = f (a_N)
1237
- @nexprs 4 N-> v_N = op (v_N, fa_N)
1256
+ @nexprs 8 N-> fa_N = f (a_N)
1257
+ @nexprs 8 N-> v_N = op (v_N, fa_N)
1238
1258
end
1239
- v = op (op (v_1, v_2), op (v_3, v_4))
1240
- i = (n>> 2 ) * 4
1241
- @nexprs 4 N-> begin
1259
+ v = op (op (op ( v_1, v_2), op (v_3, v_4)), op ( op (v_5, v_6), op (v_7, v_8) ))
1260
+ i = (n>> 3 ) * 8
1261
+ @nexprs 8 N-> begin
1242
1262
it = iterate (itr, s)
1243
1263
if it === nothing
1244
1264
return Some (v)
0 commit comments