Skip to content

Commit d945d95

Browse files
authored
Merge pull request #1150 from IntelPython/add-threeoffsets-indexer
Implements indexer for handling three offsets
2 parents df0bf84 + 0f32e41 commit d945d95

File tree

1 file changed

+86
-2
lines changed

1 file changed

+86
-2
lines changed

dpctl/tensor/libtensor/include/utils/offset_utils.hpp

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ struct TwoOffsets_StridedIndexer
243243
_ind.get_displacement<const py::ssize_t *, const py::ssize_t *>(
244244
gid,
245245
shape_strides, // shape ptr
246-
shape_strides + nd, // src strides ptr
247-
shape_strides + 2 * nd, // src strides ptr
246+
shape_strides + nd, // strides ptr
247+
shape_strides + 2 * nd, // strides ptr
248248
relative_first_offset, relative_second_offset);
249249
return TwoOffsets<py::ssize_t>(
250250
starting_first_offset + relative_first_offset,
@@ -268,6 +268,90 @@ struct TwoZeroOffsets_Indexer
268268
}
269269
};
270270

271+
template <typename displacementT> struct ThreeOffsets
272+
{
273+
ThreeOffsets() : first_offset(0), second_offset(0), third_offset(0) {}
274+
ThreeOffsets(displacementT first_offset_,
275+
displacementT second_offset_,
276+
displacementT third_offset_)
277+
: first_offset(first_offset_), second_offset(second_offset_),
278+
third_offset(third_offset_)
279+
{
280+
}
281+
282+
displacementT get_first_offset() const
283+
{
284+
return first_offset;
285+
}
286+
displacementT get_second_offset() const
287+
{
288+
return second_offset;
289+
}
290+
displacementT get_third_offset() const
291+
{
292+
return third_offset;
293+
}
294+
295+
private:
296+
displacementT first_offset = 0;
297+
displacementT second_offset = 0;
298+
displacementT third_offset = 0;
299+
};
300+
301+
struct ThreeOffsets_StridedIndexer
302+
{
303+
ThreeOffsets_StridedIndexer(int common_nd,
304+
py::ssize_t first_offset_,
305+
py::ssize_t second_offset_,
306+
py::ssize_t third_offset_,
307+
py::ssize_t const *_packed_shape_strides)
308+
: nd(common_nd), starting_first_offset(first_offset_),
309+
starting_second_offset(second_offset_),
310+
starting_third_offset(third_offset_),
311+
shape_strides(_packed_shape_strides)
312+
{
313+
}
314+
315+
ThreeOffsets<py::ssize_t> operator()(py::ssize_t gid) const
316+
{
317+
using dpctl::tensor::strides::CIndexer_vector;
318+
319+
CIndexer_vector _ind(nd);
320+
py::ssize_t relative_first_offset(0);
321+
py::ssize_t relative_second_offset(0);
322+
py::ssize_t relative_third_offset(0);
323+
_ind.get_displacement<const py::ssize_t *, const py::ssize_t *>(
324+
gid,
325+
shape_strides, // shape ptr
326+
shape_strides + nd, // strides ptr
327+
shape_strides + 2 * nd, // strides ptr
328+
shape_strides + 3 * nd, // strides ptr
329+
relative_first_offset, relative_second_offset,
330+
relative_third_offset);
331+
return ThreeOffsets<py::ssize_t>(
332+
starting_first_offset + relative_first_offset,
333+
starting_second_offset + relative_second_offset,
334+
starting_third_offset + relative_third_offset);
335+
}
336+
337+
private:
338+
int nd;
339+
py::ssize_t starting_first_offset;
340+
py::ssize_t starting_second_offset;
341+
py::ssize_t starting_third_offset;
342+
py::ssize_t const *shape_strides;
343+
};
344+
345+
struct ThreeZeroOffsets_Indexer
346+
{
347+
ThreeZeroOffsets_Indexer() {}
348+
349+
ThreeOffsets<py::ssize_t> operator()(py::ssize_t) const
350+
{
351+
return ThreeOffsets<py::ssize_t>();
352+
}
353+
};
354+
271355
struct NthStrideOffset
272356
{
273357
NthStrideOffset(int common_nd,

0 commit comments

Comments
 (0)