@@ -58,47 +58,97 @@ template <typename T> struct boolean_predicate
     }
 };
 
-template <typename inpT, typename outT, typename PredicateT, size_t wg_dim>
+template <typename inpT,
+          typename outT,
+          typename PredicateT,
+          std::uint8_t wg_dim = 2>
 struct all_reduce_wg_contig
 {
-    outT operator()(sycl::group<wg_dim> &wg,
+    void operator()(sycl::nd_item<wg_dim> &ndit,
+                    outT *out,
+                    size_t &out_idx,
                     const inpT *start,
                     const inpT *end) const
     {
         PredicateT pred{};
-        return static_cast<outT>(sycl::joint_all_of(wg, start, end, pred));
+        auto wg = ndit.get_group();
+        outT red_val_over_wg =
+            static_cast<outT>(sycl::joint_all_of(wg, start, end, pred));
+
+        if (wg.leader()) {
+            sycl::atomic_ref<outT, sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::global_space>
+                res_ref(out[out_idx]);
+            res_ref.fetch_and(red_val_over_wg);
+        }
     }
 };
 
-template <typename inpT, typename outT, typename PredicateT, size_t wg_dim>
+template <typename inpT,
+          typename outT,
+          typename PredicateT,
+          std::uint8_t wg_dim = 2>
 struct any_reduce_wg_contig
 {
-    outT operator()(sycl::group<wg_dim> &wg,
+    void operator()(sycl::nd_item<wg_dim> &ndit,
+                    outT *out,
+                    size_t &out_idx,
                     const inpT *start,
                     const inpT *end) const
     {
         PredicateT pred{};
-        return static_cast<outT>(sycl::joint_any_of(wg, start, end, pred));
+        auto wg = ndit.get_group();
+        outT red_val_over_wg =
+            static_cast<outT>(sycl::joint_any_of(wg, start, end, pred));
+
+        if (wg.leader()) {
+            sycl::atomic_ref<outT, sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::global_space>
+                res_ref(out[out_idx]);
+            res_ref.fetch_or(red_val_over_wg);
+        }
     }
 };
 
-template <typename T, typename PredicateT, size_t wg_dim>
-struct all_reduce_wg_strided
+template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
 {
-    T operator()(sycl::group<wg_dim> &wg, const T &local_val) const
+    void operator()(sycl::nd_item<wg_dim> &ndit,
+                    T *out,
+                    const size_t &out_idx,
+                    const T &local_val) const
     {
-        PredicateT pred{};
-        return static_cast<T>(sycl::all_of_group(wg, local_val, pred));
+        auto wg = ndit.get_group();
+        T red_val_over_wg = static_cast<T>(sycl::all_of_group(wg, local_val));
+
+        if (wg.leader()) {
+            sycl::atomic_ref<T, sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::global_space>
+                res_ref(out[out_idx]);
+            res_ref.fetch_and(red_val_over_wg);
+        }
     }
 };
 
-template <typename T, typename PredicateT, size_t wg_dim>
-struct any_reduce_wg_strided
+template <typename T, std::uint8_t wg_dim = 2> struct any_reduce_wg_strided
 {
-    T operator()(sycl::group<wg_dim> &wg, const T &local_val) const
+    void operator()(sycl::nd_item<wg_dim> &ndit,
+                    T *out,
+                    const size_t &out_idx,
+                    const T &local_val) const
     {
-        PredicateT pred{};
-        return static_cast<T>(sycl::any_of_group(wg, local_val, pred));
+        auto wg = ndit.get_group();
+        T red_val_over_wg = static_cast<T>(sycl::any_of_group(wg, local_val));
+
+        if (wg.leader()) {
+            sycl::atomic_ref<T, sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::global_space>
+                res_ref(out[out_idx]);
+            res_ref.fetch_or(red_val_over_wg);
+        }
     }
 };
 
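Taken together, these four functors now receive the kernel's nd_item, the output pointer, and the output index, and perform both the work-group-level reduction and the atomic update of the result slot. Below is a minimal self-contained sketch of the contiguous pattern, assuming a SYCL 2020 compiler (e.g. DPC++) and USM shared memory; the sizes, data, and predicate are illustrative and not taken from this patch:

// Sketch only: the same reduce-then-atomically-combine pattern as
// all_reduce_wg_contig above, as a standalone program.
#include <sycl/sycl.hpp>

#include <cstdint>
#include <iostream>

int main()
{
    sycl::queue q;
    constexpr size_t n = 1024, wg = 256;

    int *data = sycl::malloc_shared<int>(n, q);
    std::int32_t *res = sycl::malloc_shared<std::int32_t>(1, q);
    for (size_t i = 0; i < n; ++i) {
        data[i] = 1; // every element satisfies the predicate
    }
    *res = 1; // identity of logical AND

    q.parallel_for(
         sycl::nd_range<1>{n, wg},
         [=](sycl::nd_item<1> ndit) {
             auto grp = ndit.get_group();
             size_t start = grp.get_group_id(0) * wg;
             // cooperative predicate over this work-group's chunk
             bool wg_all =
                 sycl::joint_all_of(grp, data + start, data + start + wg,
                                    [](int v) { return v > 0; });
             if (grp.leader()) {
                 sycl::atomic_ref<std::int32_t, sycl::memory_order::relaxed,
                                  sycl::memory_scope::device,
                                  sycl::access::address_space::global_space>
                     res_ref(*res);
                 // bitwise AND of 0/1 values implements logical AND
                 res_ref.fetch_and(static_cast<std::int32_t>(wg_all));
             }
         })
        .wait();

    std::cout << "all positive: " << *res << "\n"; // prints 1
    sycl::free(data, q);
    sycl::free(res, q);
    return 0;
}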
@@ -137,8 +187,10 @@ struct SequentialBooleanReduction
     {
 
         auto inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
-        const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset();
-        const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset();
+        const size_t &inp_iter_offset =
+            inp_out_iter_offsets_.get_first_offset();
+        const size_t &out_iter_offset =
+            inp_out_iter_offsets_.get_second_offset();
 
         outT red_val(identity_);
         for (size_t m = 0; m < reduction_max_gid_; ++m) {
@@ -156,26 +208,24 @@ struct SequentialBooleanReduction
     }
 };
 
-template <typename argT, typename outT, typename ReductionOp, typename GroupOp>
+template <typename argT, typename outT, typename GroupOp>
 struct ContigBooleanReduction
 {
private:
     const argT *inp_ = nullptr;
     outT *out_ = nullptr;
-    ReductionOp reduction_op_;
     GroupOp group_op_;
     size_t reduction_max_gid_ = 0;
     size_t reductions_per_wi = 16;
 
public:
     ContigBooleanReduction(const argT *inp,
                            outT *res,
-                           ReductionOp reduction_op,
                            GroupOp group_op,
                            size_t reduction_size,
                            size_t reduction_size_per_wi)
-        : inp_(inp), out_(res), reduction_op_(reduction_op),
-          group_op_(group_op), reduction_max_gid_(reduction_size),
+        : inp_(inp), out_(res), group_op_(group_op),
+          reduction_max_gid_(reduction_size),
           reductions_per_wi(reduction_size_per_wi)
     {
     }
@@ -185,30 +235,15 @@ struct ContigBooleanReduction
 
         size_t reduction_id = it.get_group(0);
         size_t reduction_batch_id = it.get_group(1);
-
-        auto work_group = it.get_group();
         size_t wg_size = it.get_local_range(1);
 
         size_t base = reduction_id * reduction_max_gid_;
         size_t start = base + reduction_batch_id * wg_size * reductions_per_wi;
         size_t end = std::min((start + (reductions_per_wi * wg_size)),
                               base + reduction_max_gid_);
-
-        // reduction to the work group level is performed
-        // inside group_op
-        outT red_val_over_wg = group_op_(work_group, inp_ + start, inp_ + end);
-
-        if (work_group.leader()) {
-            sycl::atomic_ref<outT, sycl::memory_order::relaxed,
-                             sycl::memory_scope::device,
-                             sycl::access::address_space::global_space>
-                res_ref(out_[reduction_id]);
-            outT read_val = res_ref.load();
-            outT new_val{};
-            do {
-                new_val = reduction_op_(read_val, red_val_over_wg);
-            } while (!res_ref.compare_exchange_strong(read_val, new_val));
-        }
+        // reduction and atomic operations are performed
+        // in group_op_
+        group_op_(it, out_, reduction_id, inp_ + start, inp_ + end);
     }
 };
 
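Dropping the compare_exchange_strong loop in favor of fetch_and/fetch_or inside the group functor is sound here because the result type is std::int32_t restricted to the values 0 and 1, where the bitwise atomics coincide with the logical operations; relaxed ordering suffices since the combine is commutative and associative and the result is only read after the event completes. A quick host-side check of that identity (illustrative, not part of the patch):

// For values restricted to {0, 1}, bitwise & and | agree with logical
// && and ||, so fetch_and/fetch_or need no CAS loop.
#include <cassert>
#include <cstdint>

int main()
{
    for (std::int32_t a : {0, 1}) {
        for (std::int32_t b : {0, 1}) {
            assert((a & b) == static_cast<std::int32_t>(a && b));
            assert((a | b) == static_cast<std::int32_t>(a || b));
        }
    }
    return 0;
}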
@@ -223,7 +258,7 @@ typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
     py::ssize_t,
     const std::vector<sycl::event> &);
 
-template <typename T1, typename T2, typename T3, typename T4>
+template <typename T1, typename T2, typename T3>
 class boolean_reduction_contig_krn;
 
 template <typename T1, typename T2, typename T3, typename T4, typename T5>
@@ -298,7 +333,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
         red_ev = exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(init_ev);
 
-            constexpr size_t group_dim = 2;
+            constexpr std::uint8_t group_dim = 2;
 
             constexpr size_t preferred_reductions_per_wi = 4;
             size_t reductions_per_wi =
@@ -314,11 +349,11 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
                 sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
             auto lws = sycl::range<group_dim>{1, wg};
 
-            cgh.parallel_for<class boolean_reduction_contig_krn<
-                argTy, resTy, RedOpT, GroupOpT>>(
+            cgh.parallel_for<
+                class boolean_reduction_contig_krn<argTy, resTy, GroupOpT>>(
                 sycl::nd_range<group_dim>(gws, lws),
-                ContigBooleanReduction<argTy, resTy, RedOpT, GroupOpT>(
-                    arg_tp, res_tp, RedOpT(), GroupOpT(), reduction_nelems,
+                ContigBooleanReduction<argTy, resTy, GroupOpT>(
+                    arg_tp, res_tp, GroupOpT(), reduction_nelems,
                     reductions_per_wi));
         });
     }
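For reference, dimension 0 of this nd_range indexes the iter_nelems independent reductions and dimension 1 the work-groups cooperating on one reduction, each covering wg * reductions_per_wi contiguous elements. A worked example of the geometry with assumed sizes (the numbers below are not from this patch):

// Worked example of the nd_range geometry with assumed sizes.
#include <cstddef>
#include <iostream>

int main()
{
    std::size_t iter_nelems = 32;         // independent reductions (dim 0)
    std::size_t reduction_nelems = 10000; // elements per reduction
    std::size_t wg = 256, reductions_per_wi = 4;

    // each work-group covers wg * reductions_per_wi contiguous elements
    std::size_t elems_per_group = wg * reductions_per_wi; // 1024
    std::size_t reduction_groups =
        (reduction_nelems + elems_per_group - 1) / elems_per_group; // 10

    // gws = {iter_nelems, reduction_groups * wg}, lws = {1, wg}
    std::cout << "gws = {" << iter_nelems << ", " << reduction_groups * wg
              << "}, lws = {1, " << wg << "}\n"; // gws = {32, 2560}
    return 0;
}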
@@ -332,7 +367,7 @@ template <typename fnT, typename srcTy> struct AllContigFactory
         using resTy = std::int32_t;
         using RedOpT = sycl::logical_and<resTy>;
         using GroupOpT =
-            all_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>, 2>;
+            all_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;
 
         return dpctl::tensor::kernels::boolean_reduction_contig_impl<
             srcTy, resTy, RedOpT, GroupOpT>;
@@ -346,7 +381,7 @@ template <typename fnT, typename srcTy> struct AnyContigFactory
         using resTy = std::int32_t;
         using RedOpT = sycl::logical_or<resTy>;
         using GroupOpT =
-            any_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>, 2>;
+            any_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;
 
         return dpctl::tensor::kernels::boolean_reduction_contig_impl<
             srcTy, resTy, RedOpT, GroupOpT>;
@@ -400,8 +435,10 @@ struct StridedBooleanReduction
         size_t wg_size = it.get_local_range(1);
 
         auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id);
-        const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset();
-        const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset();
+        const size_t &inp_iter_offset =
+            inp_out_iter_offsets_.get_first_offset();
+        const size_t &out_iter_offset =
+            inp_out_iter_offsets_.get_second_offset();
 
         outT local_red_val(identity_);
         size_t arg_reduce_gid0 =
@@ -416,28 +453,15 @@ struct StridedBooleanReduction
 
                 // must convert to boolean first to handle nans
                 using dpctl::tensor::type_utils::convert_impl;
-                outT val = convert_impl<bool, argT>(inp_[inp_offset]);
+                bool val = convert_impl<bool, argT>(inp_[inp_offset]);
 
-                local_red_val = reduction_op_(local_red_val, val);
+                local_red_val =
+                    reduction_op_(local_red_val, static_cast<outT>(val));
             }
         }
-
-        // reduction to the work group level is performed
-        // inside group_op
-        auto work_group = it.get_group();
-        outT red_val_over_wg = group_op_(work_group, local_red_val);
-
-        if (work_group.leader()) {
-            sycl::atomic_ref<outT, sycl::memory_order::relaxed,
-                             sycl::memory_scope::device,
-                             sycl::access::address_space::global_space>
-                res_ref(out_[out_iter_offset]);
-            outT read_val = res_ref.load();
-            outT new_val{};
-            do {
-                new_val = reduction_op_(read_val, red_val_over_wg);
-            } while (!res_ref.compare_exchange_strong(read_val, new_val));
-        }
+        // reduction and atomic operations are performed
+        // in group_op_
+        group_op_(it, out_, out_iter_offset, local_red_val);
     }
 };
 
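The bool-first conversion this hunk preserves matters for floating-point inputs: NaN converts to true (it compares unequal to zero), matching NumPy's any/all semantics, whereas accumulating raw outT values could mishandle it. A small host-side illustration, with a plain static_cast standing in for dpctl's convert_impl:

// Host-side illustration of the bool-first conversion.
#include <cassert>
#include <cstdint>
#include <limits>

int main()
{
    float nan = std::numeric_limits<float>::quiet_NaN();
    bool val = static_cast<bool>(nan); // NaN != 0, so it converts to true
    assert(val);
    assert(static_cast<std::int32_t>(val) == 1);
    return 0;
}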
@@ -541,7 +565,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
         red_ev = exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(res_init_ev);
 
-            constexpr size_t group_dim = 2;
+            constexpr std::uint8_t group_dim = 2;
 
             using InputOutputIterIndexerT =
                 dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
@@ -589,8 +613,7 @@ template <typename fnT, typename srcTy> struct AllStridedFactory
     {
         using resTy = std::int32_t;
         using RedOpT = sycl::logical_and<resTy>;
-        using GroupOpT =
-            all_reduce_wg_strided<resTy, boolean_predicate<srcTy>, 2>;
+        using GroupOpT = all_reduce_wg_strided<resTy>;
 
         return dpctl::tensor::kernels::boolean_reduction_strided_impl<
             srcTy, resTy, RedOpT, GroupOpT>;
@@ -603,8 +626,7 @@ template <typename fnT, typename srcTy> struct AnyStridedFactory
     {
         using resTy = std::int32_t;
         using RedOpT = sycl::logical_or<resTy>;
-        using GroupOpT =
-            any_reduce_wg_strided<resTy, boolean_predicate<srcTy>, 2>;
+        using GroupOpT = any_reduce_wg_strided<resTy>;
 
         return dpctl::tensor::kernels::boolean_reduction_strided_impl<
             srcTy, resTy, RedOpT, GroupOpT>;