31
31
#include < utility>
32
32
#include < vector>
33
33
34
- #include " utils/strided_iters .hpp"
34
+ #include " utils/offset_utils .hpp"
35
35
#include " utils/type_dispatch.hpp"
36
36
37
37
namespace dpctl
@@ -45,6 +45,8 @@ namespace indexing
45
45
46
46
namespace py = pybind11;
47
47
48
+ using namespace dpctl ::tensor::offset_utils;
49
+
48
50
template <typename T> T ceiling_quotient (T n, T m)
49
51
{
50
52
return (n + m - 1 ) / m;
@@ -67,82 +69,6 @@ template <typename inputT,
67
69
typename TransformerT>
68
70
class inclusive_scan_rec_chunk_update_krn ;
69
71
70
- struct NoOpIndexer
71
- {
72
- size_t operator ()(size_t gid) const
73
- {
74
- return gid;
75
- }
76
- };
77
-
78
- struct StridedIndexer
79
- {
80
- StridedIndexer (int _nd,
81
- py::ssize_t _offset,
82
- py::ssize_t const *_packed_shape_strides)
83
- : nd(_nd), starting_offset(_offset),
84
- shape_strides (_packed_shape_strides)
85
- {
86
- }
87
-
88
- size_t operator ()(size_t gid) const
89
- {
90
- CIndexer_vector _ind (nd);
91
- py::ssize_t relative_offset (0 );
92
- _ind.get_displacement <const py::ssize_t *, const py::ssize_t *>(
93
- static_cast <py::ssize_t >(gid),
94
- shape_strides, // shape ptr
95
- shape_strides + nd, // strides ptr
96
- relative_offset);
97
- return starting_offset + relative_offset;
98
- }
99
-
100
- private:
101
- int nd;
102
- py::ssize_t starting_offset;
103
- py::ssize_t const *shape_strides;
104
- };
105
-
106
- struct Strided1DIndexer
107
- {
108
- Strided1DIndexer (py::ssize_t _offset, py::ssize_t _size, py::ssize_t _step)
109
- : offset(_offset), size(static_cast <size_t >(_size)), step(_step)
110
- {
111
- }
112
-
113
- size_t operator ()(size_t gid) const
114
- {
115
- // ensure 0 <= gid < size
116
- return static_cast <size_t >(offset +
117
- std::min<size_t >(gid, size - 1 ) * step);
118
- }
119
-
120
- private:
121
- py::ssize_t offset = 0 ;
122
- size_t size = 1 ;
123
- py::ssize_t step = 1 ;
124
- };
125
-
126
- struct Strided1DCyclicIndexer
127
- {
128
- Strided1DCyclicIndexer (py::ssize_t _offset,
129
- py::ssize_t _size,
130
- py::ssize_t _step)
131
- : offset(_offset), size(static_cast <size_t >(_size)), step(_step)
132
- {
133
- }
134
-
135
- size_t operator ()(size_t gid) const
136
- {
137
- return static_cast <size_t >(offset + (gid % size) * step);
138
- }
139
-
140
- private:
141
- py::ssize_t offset = 0 ;
142
- size_t size = 1 ;
143
- py::ssize_t step = 1 ;
144
- };
145
-
146
72
template <typename inputT, typename outputT> struct NonZeroIndicator
147
73
{
148
74
NonZeroIndicator () {}
@@ -200,9 +126,9 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
200
126
201
127
slmT slm_iscan_tmp (lws, cgh);
202
128
203
- cgh.parallel_for <class inclusive_scan_rec_local_scan_krn <inputT, outputT, n_wi, IndexerT, decltype (transformer)>>(
204
- sycl::nd_range< 1 >(gws, lws),
205
- [=](sycl::nd_item<1 > it)
129
+ cgh.parallel_for <class inclusive_scan_rec_local_scan_krn <
130
+ inputT, outputT, n_wi, IndexerT, decltype (transformer)>>(
131
+ sycl::nd_range< 1 >(gws, lws), [=](sycl::nd_item<1 > it)
206
132
{
207
133
auto chunk_gid = it.get_global_id (0 );
208
134
auto lid = it.get_local_id (0 );
@@ -245,8 +171,7 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
245
171
for (size_t m_wi = 0 ; m_wi < n_wi && i + m_wi < n_elems; ++m_wi) {
246
172
output[i + m_wi] = local_isum[m_wi];
247
173
}
248
- }
249
- );
174
+ });
250
175
});
251
176
252
177
sycl::event out_event = inc_scan_phase1_ev;
@@ -266,15 +191,14 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
266
191
// output[ chunk_size * (i + 1) + j] += temp[i]
267
192
auto e3 = exec_q.submit ([&](sycl::handler &cgh) {
268
193
cgh.depends_on (e2 );
269
- cgh.parallel_for <class inclusive_scan_rec_chunk_update_krn <inputT, outputT, IndexerT, decltype (transformer)>>(
270
- {n_elems},
271
- [=](auto wiid)
194
+ cgh.parallel_for <class inclusive_scan_rec_chunk_update_krn <
195
+ inputT, outputT, IndexerT, decltype (transformer)>>(
196
+ {n_elems}, [=](auto wiid)
272
197
{
273
198
auto gid = wiid[0 ];
274
199
auto i = (gid / chunk_size);
275
200
output[gid] += (i > 0 ) ? temp[i - 1 ] : 0 ;
276
- }
277
- );
201
+ });
278
202
});
279
203
280
204
sycl::event e4 = exec_q.submit ([&](sycl::handler &cgh) {
@@ -289,73 +213,6 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
289
213
return out_event;
290
214
}
291
215
292
- template <typename displacementT> struct TwoOffsets
293
- {
294
- TwoOffsets () : first_offset(0 ), second_offset(0 ) {}
295
- TwoOffsets (displacementT first_offset_, displacementT second_offset_)
296
- : first_offset(first_offset_), second_offset(second_offset_)
297
- {
298
- }
299
-
300
- displacementT get_first_offset () const
301
- {
302
- return first_offset;
303
- }
304
- displacementT get_second_offset () const
305
- {
306
- return second_offset;
307
- }
308
-
309
- private:
310
- displacementT first_offset = 0 ;
311
- displacementT second_offset = 0 ;
312
- };
313
-
314
- struct TwoOffsets_StridedIndexer
315
- {
316
- TwoOffsets_StridedIndexer (int common_nd,
317
- py::ssize_t first_offset_,
318
- py::ssize_t second_offset_,
319
- py::ssize_t const *_packed_shape_strides)
320
- : nd(common_nd), starting_first_offset(first_offset_),
321
- starting_second_offset (second_offset_),
322
- shape_strides(_packed_shape_strides)
323
- {
324
- }
325
-
326
- TwoOffsets<py::ssize_t > operator ()(py::ssize_t gid) const
327
- {
328
- CIndexer_vector _ind (nd);
329
- py::ssize_t relative_first_offset (0 );
330
- py::ssize_t relative_second_offset (0 );
331
- _ind.get_displacement <const py::ssize_t *, const py::ssize_t *>(
332
- gid,
333
- shape_strides, // shape ptr
334
- shape_strides + nd, // src strides ptr
335
- shape_strides + 2 * nd, // src strides ptr
336
- relative_first_offset, relative_second_offset);
337
- return TwoOffsets<py::ssize_t >(
338
- starting_first_offset + relative_first_offset,
339
- starting_second_offset + relative_second_offset);
340
- }
341
-
342
- private:
343
- int nd;
344
- py::ssize_t starting_first_offset;
345
- py::ssize_t starting_second_offset;
346
- py::ssize_t const *shape_strides;
347
- };
348
-
349
- struct TwoZeroOffsets_Indexer
350
- {
351
- TwoZeroOffsets_Indexer () {}
352
-
353
- TwoOffsets<py::ssize_t > operator ()(py::ssize_t ) const
354
- {
355
- return TwoOffsets<py::ssize_t >();
356
- }
357
- };
358
-
359
216
template <typename OrthogIndexerT,
360
217
typename MaskedSrcIndexerT,
361
218
typename MaskedDstIndexerT,
0 commit comments