@@ -243,8 +243,8 @@ struct TwoOffsets_StridedIndexer
243
243
_ind.get_displacement <const py::ssize_t *, const py::ssize_t *>(
244
244
gid,
245
245
shape_strides, // shape ptr
246
- shape_strides + nd, // src strides ptr
247
- shape_strides + 2 * nd, // src strides ptr
246
+ shape_strides + nd, // strides ptr
247
+ shape_strides + 2 * nd, // strides ptr
248
248
relative_first_offset, relative_second_offset);
249
249
return TwoOffsets<py::ssize_t >(
250
250
starting_first_offset + relative_first_offset,
@@ -268,6 +268,90 @@ struct TwoZeroOffsets_Indexer
268
268
}
269
269
};
270
270
271
+ template <typename displacementT> struct ThreeOffsets
272
+ {
273
+ ThreeOffsets () : first_offset(0 ), second_offset(0 ), third_offset(0 ) {}
274
+ ThreeOffsets (displacementT first_offset_,
275
+ displacementT second_offset_,
276
+ displacementT third_offset_)
277
+ : first_offset(first_offset_), second_offset(second_offset_),
278
+ third_offset (third_offset_)
279
+ {
280
+ }
281
+
282
+ displacementT get_first_offset () const
283
+ {
284
+ return first_offset;
285
+ }
286
+ displacementT get_second_offset () const
287
+ {
288
+ return second_offset;
289
+ }
290
+ displacementT get_third_offset () const
291
+ {
292
+ return third_offset;
293
+ }
294
+
295
+ private:
296
+ displacementT first_offset = 0 ;
297
+ displacementT second_offset = 0 ;
298
+ displacementT third_offset = 0 ;
299
+ };
300
+
301
+ struct ThreeOffsets_StridedIndexer
302
+ {
303
+ ThreeOffsets_StridedIndexer (int common_nd,
304
+ py::ssize_t first_offset_,
305
+ py::ssize_t second_offset_,
306
+ py::ssize_t third_offset_,
307
+ py::ssize_t const *_packed_shape_strides)
308
+ : nd(common_nd), starting_first_offset(first_offset_),
309
+ starting_second_offset (second_offset_),
310
+ starting_third_offset(third_offset_),
311
+ shape_strides(_packed_shape_strides)
312
+ {
313
+ }
314
+
315
+ ThreeOffsets<py::ssize_t > operator ()(py::ssize_t gid) const
316
+ {
317
+ using dpctl::tensor::strides::CIndexer_vector;
318
+
319
+ CIndexer_vector _ind (nd);
320
+ py::ssize_t relative_first_offset (0 );
321
+ py::ssize_t relative_second_offset (0 );
322
+ py::ssize_t relative_third_offset (0 );
323
+ _ind.get_displacement <const py::ssize_t *, const py::ssize_t *>(
324
+ gid,
325
+ shape_strides, // shape ptr
326
+ shape_strides + nd, // strides ptr
327
+ shape_strides + 2 * nd, // strides ptr
328
+ shape_strides + 3 * nd, // strides ptr
329
+ relative_first_offset, relative_second_offset,
330
+ relative_third_offset);
331
+ return ThreeOffsets<py::ssize_t >(
332
+ starting_first_offset + relative_first_offset,
333
+ starting_second_offset + relative_second_offset,
334
+ starting_third_offset + relative_third_offset);
335
+ }
336
+
337
+ private:
338
+ int nd;
339
+ py::ssize_t starting_first_offset;
340
+ py::ssize_t starting_second_offset;
341
+ py::ssize_t starting_third_offset;
342
+ py::ssize_t const *shape_strides;
343
+ };
344
+
345
+ struct ThreeZeroOffsets_Indexer
346
+ {
347
+ ThreeZeroOffsets_Indexer () {}
348
+
349
+ ThreeOffsets<py::ssize_t > operator ()(py::ssize_t ) const
350
+ {
351
+ return ThreeOffsets<py::ssize_t >();
352
+ }
353
+ };
354
+
271
355
struct NthStrideOffset
272
356
{
273
357
NthStrideOffset (int common_nd,
0 commit comments