Skip to content

Commit eed7e45

Browse files
committed
Made changes suggested by @oleksandr-pavlyk
1 parent 9d05a14 commit eed7e45

File tree

2 files changed

+100
-77
lines changed

2 files changed

+100
-77
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp

Lines changed: 97 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -58,47 +58,97 @@ template <typename T> struct boolean_predicate
5858
}
5959
};
6060

61-
template <typename inpT, typename outT, typename PredicateT, size_t wg_dim>
61+
template <typename inpT,
62+
typename outT,
63+
typename PredicateT,
64+
std::uint8_t wg_dim = 2>
6265
struct all_reduce_wg_contig
6366
{
64-
outT operator()(sycl::group<wg_dim> &wg,
67+
void operator()(sycl::nd_item<wg_dim> &ndit,
68+
outT *out,
69+
size_t &out_idx,
6570
const inpT *start,
6671
const inpT *end) const
6772
{
6873
PredicateT pred{};
69-
return static_cast<outT>(sycl::joint_all_of(wg, start, end, pred));
74+
auto wg = ndit.get_group();
75+
outT red_val_over_wg =
76+
static_cast<outT>(sycl::joint_all_of(wg, start, end, pred));
77+
78+
if (wg.leader()) {
79+
sycl::atomic_ref<outT, sycl::memory_order::relaxed,
80+
sycl::memory_scope::device,
81+
sycl::access::address_space::global_space>
82+
res_ref(out[out_idx]);
83+
res_ref.fetch_and(red_val_over_wg);
84+
}
7085
}
7186
};
7287

73-
template <typename inpT, typename outT, typename PredicateT, size_t wg_dim>
88+
template <typename inpT,
89+
typename outT,
90+
typename PredicateT,
91+
std::uint8_t wg_dim = 2>
7492
struct any_reduce_wg_contig
7593
{
76-
outT operator()(sycl::group<wg_dim> &wg,
94+
void operator()(sycl::nd_item<wg_dim> &ndit,
95+
outT *out,
96+
size_t &out_idx,
7797
const inpT *start,
7898
const inpT *end) const
7999
{
80100
PredicateT pred{};
81-
return static_cast<outT>(sycl::joint_any_of(wg, start, end, pred));
101+
auto wg = ndit.get_group();
102+
outT red_val_over_wg =
103+
static_cast<outT>(sycl::joint_any_of(wg, start, end, pred));
104+
105+
if (wg.leader()) {
106+
sycl::atomic_ref<outT, sycl::memory_order::relaxed,
107+
sycl::memory_scope::device,
108+
sycl::access::address_space::global_space>
109+
res_ref(out[out_idx]);
110+
res_ref.fetch_or(red_val_over_wg);
111+
}
82112
}
83113
};
84114

85-
template <typename T, typename PredicateT, size_t wg_dim>
86-
struct all_reduce_wg_strided
115+
template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
87116
{
88-
T operator()(sycl::group<wg_dim> &wg, const T &local_val) const
117+
void operator()(sycl::nd_item<wg_dim> &ndit,
118+
T *out,
119+
const size_t &out_idx,
120+
const T &local_val) const
89121
{
90-
PredicateT pred{};
91-
return static_cast<T>(sycl::all_of_group(wg, local_val, pred));
122+
auto wg = ndit.get_group();
123+
T red_val_over_wg = static_cast<T>(sycl::all_of_group(wg, local_val));
124+
125+
if (wg.leader()) {
126+
sycl::atomic_ref<T, sycl::memory_order::relaxed,
127+
sycl::memory_scope::device,
128+
sycl::access::address_space::global_space>
129+
res_ref(out[out_idx]);
130+
res_ref.fetch_and(red_val_over_wg);
131+
}
92132
}
93133
};
94134

95-
template <typename T, typename PredicateT, size_t wg_dim>
96-
struct any_reduce_wg_strided
135+
template <typename T, std::uint8_t wg_dim = 2> struct any_reduce_wg_strided
97136
{
98-
T operator()(sycl::group<wg_dim> &wg, const T &local_val) const
137+
void operator()(sycl::nd_item<wg_dim> &ndit,
138+
T *out,
139+
const size_t &out_idx,
140+
const T &local_val) const
99141
{
100-
PredicateT pred{};
101-
return static_cast<T>(sycl::any_of_group(wg, local_val, pred));
142+
auto wg = ndit.get_group();
143+
T red_val_over_wg = static_cast<T>(sycl::any_of_group(wg, local_val));
144+
145+
if (wg.leader()) {
146+
sycl::atomic_ref<T, sycl::memory_order::relaxed,
147+
sycl::memory_scope::device,
148+
sycl::access::address_space::global_space>
149+
res_ref(out[out_idx]);
150+
res_ref.fetch_or(red_val_over_wg);
151+
}
102152
}
103153
};
104154

@@ -137,8 +187,10 @@ struct SequentialBooleanReduction
137187
{
138188

139189
auto inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
140-
const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset();
141-
const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset();
190+
const size_t &inp_iter_offset =
191+
inp_out_iter_offsets_.get_first_offset();
192+
const size_t &out_iter_offset =
193+
inp_out_iter_offsets_.get_second_offset();
142194

143195
outT red_val(identity_);
144196
for (size_t m = 0; m < reduction_max_gid_; ++m) {
@@ -156,26 +208,24 @@ struct SequentialBooleanReduction
156208
}
157209
};
158210

159-
template <typename argT, typename outT, typename ReductionOp, typename GroupOp>
211+
template <typename argT, typename outT, typename GroupOp>
160212
struct ContigBooleanReduction
161213
{
162214
private:
163215
const argT *inp_ = nullptr;
164216
outT *out_ = nullptr;
165-
ReductionOp reduction_op_;
166217
GroupOp group_op_;
167218
size_t reduction_max_gid_ = 0;
168219
size_t reductions_per_wi = 16;
169220

170221
public:
171222
ContigBooleanReduction(const argT *inp,
172223
outT *res,
173-
ReductionOp reduction_op,
174224
GroupOp group_op,
175225
size_t reduction_size,
176226
size_t reduction_size_per_wi)
177-
: inp_(inp), out_(res), reduction_op_(reduction_op),
178-
group_op_(group_op), reduction_max_gid_(reduction_size),
227+
: inp_(inp), out_(res), group_op_(group_op),
228+
reduction_max_gid_(reduction_size),
179229
reductions_per_wi(reduction_size_per_wi)
180230
{
181231
}
@@ -185,30 +235,15 @@ struct ContigBooleanReduction
185235

186236
size_t reduction_id = it.get_group(0);
187237
size_t reduction_batch_id = it.get_group(1);
188-
189-
auto work_group = it.get_group();
190238
size_t wg_size = it.get_local_range(1);
191239

192240
size_t base = reduction_id * reduction_max_gid_;
193241
size_t start = base + reduction_batch_id * wg_size * reductions_per_wi;
194242
size_t end = std::min((start + (reductions_per_wi * wg_size)),
195243
base + reduction_max_gid_);
196-
197-
// reduction to the work group level is performed
198-
// inside group_op
199-
outT red_val_over_wg = group_op_(work_group, inp_ + start, inp_ + end);
200-
201-
if (work_group.leader()) {
202-
sycl::atomic_ref<outT, sycl::memory_order::relaxed,
203-
sycl::memory_scope::device,
204-
sycl::access::address_space::global_space>
205-
res_ref(out_[reduction_id]);
206-
outT read_val = res_ref.load();
207-
outT new_val{};
208-
do {
209-
new_val = reduction_op_(read_val, red_val_over_wg);
210-
} while (!res_ref.compare_exchange_strong(read_val, new_val));
211-
}
244+
// reduction and atomic operations are performed
245+
// in group_op_
246+
group_op_(it, out_, reduction_id, inp_ + start, inp_ + end);
212247
}
213248
};
214249

@@ -223,7 +258,7 @@ typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
223258
py::ssize_t,
224259
const std::vector<sycl::event> &);
225260

226-
template <typename T1, typename T2, typename T3, typename T4>
261+
template <typename T1, typename T2, typename T3>
227262
class boolean_reduction_contig_krn;
228263

229264
template <typename T1, typename T2, typename T3, typename T4, typename T5>
@@ -298,7 +333,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
298333
red_ev = exec_q.submit([&](sycl::handler &cgh) {
299334
cgh.depends_on(init_ev);
300335

301-
constexpr size_t group_dim = 2;
336+
constexpr std::uint8_t group_dim = 2;
302337

303338
constexpr size_t preferred_reductions_per_wi = 4;
304339
size_t reductions_per_wi =
@@ -314,11 +349,11 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
314349
sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
315350
auto lws = sycl::range<group_dim>{1, wg};
316351

317-
cgh.parallel_for<class boolean_reduction_contig_krn<
318-
argTy, resTy, RedOpT, GroupOpT>>(
352+
cgh.parallel_for<
353+
class boolean_reduction_contig_krn<argTy, resTy, GroupOpT>>(
319354
sycl::nd_range<group_dim>(gws, lws),
320-
ContigBooleanReduction<argTy, resTy, RedOpT, GroupOpT>(
321-
arg_tp, res_tp, RedOpT(), GroupOpT(), reduction_nelems,
355+
ContigBooleanReduction<argTy, resTy, GroupOpT>(
356+
arg_tp, res_tp, GroupOpT(), reduction_nelems,
322357
reductions_per_wi));
323358
});
324359
}
@@ -332,7 +367,7 @@ template <typename fnT, typename srcTy> struct AllContigFactory
332367
using resTy = std::int32_t;
333368
using RedOpT = sycl::logical_and<resTy>;
334369
using GroupOpT =
335-
all_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>, 2>;
370+
all_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;
336371

337372
return dpctl::tensor::kernels::boolean_reduction_contig_impl<
338373
srcTy, resTy, RedOpT, GroupOpT>;
@@ -346,7 +381,7 @@ template <typename fnT, typename srcTy> struct AnyContigFactory
346381
using resTy = std::int32_t;
347382
using RedOpT = sycl::logical_or<resTy>;
348383
using GroupOpT =
349-
any_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>, 2>;
384+
any_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;
350385

351386
return dpctl::tensor::kernels::boolean_reduction_contig_impl<
352387
srcTy, resTy, RedOpT, GroupOpT>;
@@ -400,8 +435,10 @@ struct StridedBooleanReduction
400435
size_t wg_size = it.get_local_range(1);
401436

402437
auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id);
403-
const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset();
404-
const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset();
438+
const size_t &inp_iter_offset =
439+
inp_out_iter_offsets_.get_first_offset();
440+
const size_t &out_iter_offset =
441+
inp_out_iter_offsets_.get_second_offset();
405442

406443
outT local_red_val(identity_);
407444
size_t arg_reduce_gid0 =
@@ -416,28 +453,15 @@ struct StridedBooleanReduction
416453

417454
// must convert to boolean first to handle nans
418455
using dpctl::tensor::type_utils::convert_impl;
419-
outT val = convert_impl<bool, argT>(inp_[inp_offset]);
456+
bool val = convert_impl<bool, argT>(inp_[inp_offset]);
420457

421-
local_red_val = reduction_op_(local_red_val, val);
458+
local_red_val =
459+
reduction_op_(local_red_val, static_cast<outT>(val));
422460
}
423461
}
424-
425-
// reduction to the work group level is performed
426-
// inside group_op
427-
auto work_group = it.get_group();
428-
outT red_val_over_wg = group_op_(work_group, local_red_val);
429-
430-
if (work_group.leader()) {
431-
sycl::atomic_ref<outT, sycl::memory_order::relaxed,
432-
sycl::memory_scope::device,
433-
sycl::access::address_space::global_space>
434-
res_ref(out_[out_iter_offset]);
435-
outT read_val = res_ref.load();
436-
outT new_val{};
437-
do {
438-
new_val = reduction_op_(read_val, red_val_over_wg);
439-
} while (!res_ref.compare_exchange_strong(read_val, new_val));
440-
}
462+
// reduction and atomic operations are performed
463+
// in group_op_
464+
group_op_(it, out_, out_iter_offset, local_red_val);
441465
}
442466
};
443467

@@ -541,7 +565,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
541565
red_ev = exec_q.submit([&](sycl::handler &cgh) {
542566
cgh.depends_on(res_init_ev);
543567

544-
constexpr size_t group_dim = 2;
568+
constexpr std::uint8_t group_dim = 2;
545569

546570
using InputOutputIterIndexerT =
547571
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
@@ -589,8 +613,7 @@ template <typename fnT, typename srcTy> struct AllStridedFactory
589613
{
590614
using resTy = std::int32_t;
591615
using RedOpT = sycl::logical_and<resTy>;
592-
using GroupOpT =
593-
all_reduce_wg_strided<resTy, boolean_predicate<srcTy>, 2>;
616+
using GroupOpT = all_reduce_wg_strided<resTy>;
594617

595618
return dpctl::tensor::kernels::boolean_reduction_strided_impl<
596619
srcTy, resTy, RedOpT, GroupOpT>;
@@ -603,8 +626,7 @@ template <typename fnT, typename srcTy> struct AnyStridedFactory
603626
{
604627
using resTy = std::int32_t;
605628
using RedOpT = sycl::logical_or<resTy>;
606-
using GroupOpT =
607-
any_reduce_wg_strided<resTy, boolean_predicate<srcTy>, 2>;
629+
using GroupOpT = any_reduce_wg_strided<resTy>;
608630

609631
return dpctl::tensor::kernels::boolean_reduction_strided_impl<
610632
srcTy, resTy, RedOpT, GroupOpT>;

dpctl/tensor/libtensor/source/boolean_reductions.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
#include "kernels/boolean_reductions.hpp"
4242
#include "simplify_iteration_space.hpp"
4343
#include "utils/memory_overlap.hpp"
44+
#include "utils/offset_utils.hpp"
4445
#include "utils/type_utils.hpp"
4546

4647
namespace py = pybind11;
@@ -145,7 +146,7 @@ py_boolean_reduction(dpctl::tensor::usm_ndarray src,
145146
constexpr int int32_typeid = static_cast<int>(td_ns::typenum_t::INT32);
146147
if (dst_typeid != int32_typeid) {
147148
throw py::value_error(
148-
"Unexact data type of destination array, expecting 'int32'");
149+
"Unexpected data type of destination array, expecting 'int32'");
149150
}
150151

151152
bool is_src_c_contig = src.is_c_contiguous();
@@ -156,7 +157,7 @@ py_boolean_reduction(dpctl::tensor::usm_ndarray src,
156157
if ((is_src_c_contig && is_dst_c_contig) ||
157158
(is_src_f_contig && dst_nd == 0)) {
158159
auto fn = contig_dispatch_vector[src_typeid];
159-
py::ssize_t zero_offset = 0;
160+
constexpr py::ssize_t zero_offset = 0;
160161

161162
auto red_ev = fn(exec_q, dst_nelems, red_nelems, src_data, dst_data,
162163
zero_offset, zero_offset, zero_offset, depends);

0 commit comments

Comments
 (0)