@@ -133,6 +133,7 @@ struct NoOpIndexer
133
133
}
134
134
};
135
135
136
+ /* @brief Indexer whose shape and strides arrays, each of length nd, are packed back-to-back into a single array */
136
137
struct StridedIndexer
137
138
{
138
139
StridedIndexer (int _nd,
@@ -143,24 +144,76 @@ struct StridedIndexer
143
144
{
144
145
}
145
146
147
+ size_t operator ()(py::ssize_t gid) const
148
+ {
149
+ return compute_offset (gid);
150
+ }
151
+
146
152
size_t operator ()(size_t gid) const
153
+ {
154
+ return compute_offset (static_cast <py::ssize_t >(gid));
155
+ }
156
+
157
+ private:
158
+ int nd;
159
+ py::ssize_t starting_offset;
160
+ py::ssize_t const *shape_strides;
161
+
162
+ size_t compute_offset (py::ssize_t gid) const
147
163
{
148
164
using dpctl::tensor::strides::CIndexer_vector;
149
165
150
166
CIndexer_vector _ind (nd);
151
167
py::ssize_t relative_offset (0 );
152
168
_ind.get_displacement <const py::ssize_t *, const py::ssize_t *>(
153
- static_cast <py:: ssize_t >( gid) ,
169
+ gid,
154
170
shape_strides, // shape ptr
155
171
shape_strides + nd, // strides ptr
156
172
relative_offset);
157
173
return starting_offset + relative_offset;
158
174
}
175
+ };
176
+
177
+ /* @brief Indexer with shape, strides provided separately */
178
+ struct UnpackedStridedIndexer
179
+ {
180
+ UnpackedStridedIndexer (int _nd,
181
+ py::ssize_t _offset,
182
+ py::ssize_t const *_shape,
183
+ py::ssize_t const *_strides)
184
+ : nd(_nd), starting_offset(_offset), shape(_shape), strides(_strides)
185
+ {
186
+ }
187
+
188
+ size_t operator ()(py::ssize_t gid) const
189
+ {
190
+ return compute_offset (gid);
191
+ }
192
+
193
+ size_t operator ()(size_t gid) const
194
+ {
195
+ return compute_offset (static_cast <py::ssize_t >(gid));
196
+ }
159
197
160
198
private:
161
199
int nd;
162
200
py::ssize_t starting_offset;
163
- py::ssize_t const *shape_strides;
201
+ py::ssize_t const *shape;
202
+ py::ssize_t const *strides;
203
+
204
+ size_t compute_offset (py::ssize_t gid) const
205
+ {
206
+ using dpctl::tensor::strides::CIndexer_vector;
207
+
208
+ CIndexer_vector _ind (nd);
209
+ py::ssize_t relative_offset (0 );
210
+ _ind.get_displacement <const py::ssize_t *, const py::ssize_t *>(
211
+ gid,
212
+ shape, // shape ptr
213
+ strides, // strides ptr
214
+ relative_offset);
215
+ return starting_offset + relative_offset;
216
+ }
164
217
};
165
218
166
219
struct Strided1DIndexer
@@ -206,7 +259,8 @@ struct Strided1DCyclicIndexer
206
259
template <typename displacementT> struct TwoOffsets
207
260
{
208
261
TwoOffsets () : first_offset(0 ), second_offset(0 ) {}
209
- TwoOffsets (displacementT first_offset_, displacementT second_offset_)
262
+ TwoOffsets (const displacementT &first_offset_,
263
+ const displacementT &second_offset_)
210
264
: first_offset(first_offset_), second_offset(second_offset_)
211
265
{
212
266
}
@@ -238,6 +292,22 @@ struct TwoOffsets_StridedIndexer
238
292
}
239
293
240
294
TwoOffsets<py::ssize_t > operator ()(py::ssize_t gid) const
295
+ {
296
+ return compute_offsets (gid);
297
+ }
298
+
299
+ TwoOffsets<py::ssize_t > operator ()(size_t gid) const
300
+ {
301
+ return compute_offsets (static_cast <py::ssize_t >(gid));
302
+ }
303
+
304
+ private:
305
+ int nd;
306
+ py::ssize_t starting_first_offset;
307
+ py::ssize_t starting_second_offset;
308
+ py::ssize_t const *shape_strides;
309
+
310
+ TwoOffsets<py::ssize_t > compute_offsets (py::ssize_t gid) const
241
311
{
242
312
using dpctl::tensor::strides::CIndexer_vector;
243
313
@@ -254,12 +324,6 @@ struct TwoOffsets_StridedIndexer
254
324
starting_first_offset + relative_first_offset,
255
325
starting_second_offset + relative_second_offset);
256
326
}
257
-
258
- private:
259
- int nd;
260
- py::ssize_t starting_first_offset;
261
- py::ssize_t starting_second_offset;
262
- py::ssize_t const *shape_strides;
263
327
};
264
328
265
329
struct TwoZeroOffsets_Indexer
@@ -272,12 +336,33 @@ struct TwoZeroOffsets_Indexer
272
336
}
273
337
};
274
338
339
+ template <typename FirstIndexerT, typename SecondIndexerT>
340
+ struct TwoOffsets_CombinedIndexer
341
+ {
342
+ private:
343
+ FirstIndexerT first_indexer_;
344
+ SecondIndexerT second_indexer_;
345
+
346
+ public:
347
+ TwoOffsets_CombinedIndexer (const FirstIndexerT &first_indexer,
348
+ const SecondIndexerT &second_indexer)
349
+ : first_indexer_(first_indexer), second_indexer_(second_indexer)
350
+ {
351
+ }
352
+
353
+ TwoOffsets<py::ssize_t > operator ()(py::ssize_t gid) const
354
+ {
355
+ return TwoOffsets<py::ssize_t >(first_indexer_ (gid),
356
+ second_indexer_ (gid));
357
+ }
358
+ };
359
+
275
360
template <typename displacementT> struct ThreeOffsets
276
361
{
277
362
ThreeOffsets () : first_offset(0 ), second_offset(0 ), third_offset(0 ) {}
278
- ThreeOffsets (displacementT first_offset_,
279
- displacementT second_offset_,
280
- displacementT third_offset_)
363
+ ThreeOffsets (const displacementT & first_offset_,
364
+ const displacementT & second_offset_,
365
+ const displacementT & third_offset_)
281
366
: first_offset(first_offset_), second_offset(second_offset_),
282
367
third_offset (third_offset_)
283
368
{
@@ -317,6 +402,23 @@ struct ThreeOffsets_StridedIndexer
317
402
}
318
403
319
404
ThreeOffsets<py::ssize_t > operator ()(py::ssize_t gid) const
405
+ {
406
+ return compute_offsets (gid);
407
+ }
408
+
409
+ ThreeOffsets<py::ssize_t > operator ()(size_t gid) const
410
+ {
411
+ return compute_offsets (static_cast <py::ssize_t >(gid));
412
+ }
413
+
414
+ private:
415
+ int nd;
416
+ py::ssize_t starting_first_offset;
417
+ py::ssize_t starting_second_offset;
418
+ py::ssize_t starting_third_offset;
419
+ py::ssize_t const *shape_strides;
420
+
421
+ ThreeOffsets<py::ssize_t > compute_offsets (py::ssize_t gid) const
320
422
{
321
423
using dpctl::tensor::strides::CIndexer_vector;
322
424
@@ -337,13 +439,6 @@ struct ThreeOffsets_StridedIndexer
337
439
starting_second_offset + relative_second_offset,
338
440
starting_third_offset + relative_third_offset);
339
441
}
340
-
341
- private:
342
- int nd;
343
- py::ssize_t starting_first_offset;
344
- py::ssize_t starting_second_offset;
345
- py::ssize_t starting_third_offset;
346
- py::ssize_t const *shape_strides;
347
442
};
348
443
349
444
struct ThreeZeroOffsets_Indexer
0 commit comments