Skip to content

Commit 6998a8d

Browse files
Merge pull request #1162 from IntelPython/add-combined-indexer
Add TwoOffsets_CombinedIndexer, and UnpackedStridedIndexer
2 parents 2320c92 + b61dd32 commit 6998a8d

File tree

1 file changed

+114
-19
lines changed

1 file changed

+114
-19
lines changed

dpctl/tensor/libtensor/include/utils/offset_utils.hpp

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ struct NoOpIndexer
133133
}
134134
};
135135

136+
/* @brief Indexer with shape and strides arrays of same size are packed */
136137
struct StridedIndexer
137138
{
138139
StridedIndexer(int _nd,
@@ -143,24 +144,76 @@ struct StridedIndexer
143144
{
144145
}
145146

147+
size_t operator()(py::ssize_t gid) const
148+
{
149+
return compute_offset(gid);
150+
}
151+
146152
size_t operator()(size_t gid) const
153+
{
154+
return compute_offset(static_cast<py::ssize_t>(gid));
155+
}
156+
157+
private:
158+
int nd;
159+
py::ssize_t starting_offset;
160+
py::ssize_t const *shape_strides;
161+
162+
size_t compute_offset(py::ssize_t gid) const
147163
{
148164
using dpctl::tensor::strides::CIndexer_vector;
149165

150166
CIndexer_vector _ind(nd);
151167
py::ssize_t relative_offset(0);
152168
_ind.get_displacement<const py::ssize_t *, const py::ssize_t *>(
153-
static_cast<py::ssize_t>(gid),
169+
gid,
154170
shape_strides, // shape ptr
155171
shape_strides + nd, // strides ptr
156172
relative_offset);
157173
return starting_offset + relative_offset;
158174
}
175+
};
176+
177+
/* @brief Indexer with shape, strides provided separately */
178+
struct UnpackedStridedIndexer
179+
{
180+
UnpackedStridedIndexer(int _nd,
181+
py::ssize_t _offset,
182+
py::ssize_t const *_shape,
183+
py::ssize_t const *_strides)
184+
: nd(_nd), starting_offset(_offset), shape(_shape), strides(_strides)
185+
{
186+
}
187+
188+
size_t operator()(py::ssize_t gid) const
189+
{
190+
return compute_offset(gid);
191+
}
192+
193+
size_t operator()(size_t gid) const
194+
{
195+
return compute_offset(static_cast<py::ssize_t>(gid));
196+
}
159197

160198
private:
161199
int nd;
162200
py::ssize_t starting_offset;
163-
py::ssize_t const *shape_strides;
201+
py::ssize_t const *shape;
202+
py::ssize_t const *strides;
203+
204+
size_t compute_offset(py::ssize_t gid) const
205+
{
206+
using dpctl::tensor::strides::CIndexer_vector;
207+
208+
CIndexer_vector _ind(nd);
209+
py::ssize_t relative_offset(0);
210+
_ind.get_displacement<const py::ssize_t *, const py::ssize_t *>(
211+
gid,
212+
shape, // shape ptr
213+
strides, // strides ptr
214+
relative_offset);
215+
return starting_offset + relative_offset;
216+
}
164217
};
165218

166219
struct Strided1DIndexer
@@ -206,7 +259,8 @@ struct Strided1DCyclicIndexer
206259
template <typename displacementT> struct TwoOffsets
207260
{
208261
TwoOffsets() : first_offset(0), second_offset(0) {}
209-
TwoOffsets(displacementT first_offset_, displacementT second_offset_)
262+
TwoOffsets(const displacementT &first_offset_,
263+
const displacementT &second_offset_)
210264
: first_offset(first_offset_), second_offset(second_offset_)
211265
{
212266
}
@@ -238,6 +292,22 @@ struct TwoOffsets_StridedIndexer
238292
}
239293

240294
TwoOffsets<py::ssize_t> operator()(py::ssize_t gid) const
295+
{
296+
return compute_offsets(gid);
297+
}
298+
299+
TwoOffsets<py::ssize_t> operator()(size_t gid) const
300+
{
301+
return compute_offsets(static_cast<py::ssize_t>(gid));
302+
}
303+
304+
private:
305+
int nd;
306+
py::ssize_t starting_first_offset;
307+
py::ssize_t starting_second_offset;
308+
py::ssize_t const *shape_strides;
309+
310+
TwoOffsets<py::ssize_t> compute_offsets(py::ssize_t gid) const
241311
{
242312
using dpctl::tensor::strides::CIndexer_vector;
243313

@@ -254,12 +324,6 @@ struct TwoOffsets_StridedIndexer
254324
starting_first_offset + relative_first_offset,
255325
starting_second_offset + relative_second_offset);
256326
}
257-
258-
private:
259-
int nd;
260-
py::ssize_t starting_first_offset;
261-
py::ssize_t starting_second_offset;
262-
py::ssize_t const *shape_strides;
263327
};
264328

265329
struct TwoZeroOffsets_Indexer
@@ -272,12 +336,33 @@ struct TwoZeroOffsets_Indexer
272336
}
273337
};
274338

339+
template <typename FirstIndexerT, typename SecondIndexerT>
340+
struct TwoOffsets_CombinedIndexer
341+
{
342+
private:
343+
FirstIndexerT first_indexer_;
344+
SecondIndexerT second_indexer_;
345+
346+
public:
347+
TwoOffsets_CombinedIndexer(const FirstIndexerT &first_indexer,
348+
const SecondIndexerT &second_indexer)
349+
: first_indexer_(first_indexer), second_indexer_(second_indexer)
350+
{
351+
}
352+
353+
TwoOffsets<py::ssize_t> operator()(py::ssize_t gid) const
354+
{
355+
return TwoOffsets<py::ssize_t>(first_indexer_(gid),
356+
second_indexer_(gid));
357+
}
358+
};
359+
275360
template <typename displacementT> struct ThreeOffsets
276361
{
277362
ThreeOffsets() : first_offset(0), second_offset(0), third_offset(0) {}
278-
ThreeOffsets(displacementT first_offset_,
279-
displacementT second_offset_,
280-
displacementT third_offset_)
363+
ThreeOffsets(const displacementT &first_offset_,
364+
const displacementT &second_offset_,
365+
const displacementT &third_offset_)
281366
: first_offset(first_offset_), second_offset(second_offset_),
282367
third_offset(third_offset_)
283368
{
@@ -317,6 +402,23 @@ struct ThreeOffsets_StridedIndexer
317402
}
318403

319404
ThreeOffsets<py::ssize_t> operator()(py::ssize_t gid) const
405+
{
406+
return compute_offsets(gid);
407+
}
408+
409+
ThreeOffsets<py::ssize_t> operator()(size_t gid) const
410+
{
411+
return compute_offsets(static_cast<py::ssize_t>(gid));
412+
}
413+
414+
private:
415+
int nd;
416+
py::ssize_t starting_first_offset;
417+
py::ssize_t starting_second_offset;
418+
py::ssize_t starting_third_offset;
419+
py::ssize_t const *shape_strides;
420+
421+
ThreeOffsets<py::ssize_t> compute_offsets(py::ssize_t gid) const
320422
{
321423
using dpctl::tensor::strides::CIndexer_vector;
322424

@@ -337,13 +439,6 @@ struct ThreeOffsets_StridedIndexer
337439
starting_second_offset + relative_second_offset,
338440
starting_third_offset + relative_third_offset);
339441
}
340-
341-
private:
342-
int nd;
343-
py::ssize_t starting_first_offset;
344-
py::ssize_t starting_second_offset;
345-
py::ssize_t starting_third_offset;
346-
py::ssize_t const *shape_strides;
347442
};
348443

349444
struct ThreeZeroOffsets_Indexer

0 commit comments

Comments
 (0)