Skip to content

Commit b61dd32

Browse files
Add TwoOffsets_CombinedIndexer, and UnpackedStridedIndexer
The TwoOffsets_CombinedIndexer takes two single offset indexers and combines them into a TwoOffsets struct. The UnpackedStridedIndexer is a relative of StridedIndexer, except shapes and strides are provided as separate pointers. This increases its size, but may be useful to produce indexer that only computes the second offset using shape/strides stored in packed format for two-offset indexers. These are going to be used in reduction kernels.
1 parent 7bbfce1 commit b61dd32

File tree

1 file changed

+114
-19
lines changed

1 file changed

+114
-19
lines changed

dpctl/tensor/libtensor/include/utils/offset_utils.hpp

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ struct NoOpIndexer
133133
}
134134
};
135135

136+
/* @brief Indexer with shape and strides arrays of same size are packed */
136137
struct StridedIndexer
137138
{
138139
StridedIndexer(int _nd,
@@ -143,24 +144,76 @@ struct StridedIndexer
143144
{
144145
}
145146

147+
size_t operator()(py::ssize_t gid) const
148+
{
149+
return compute_offset(gid);
150+
}
151+
146152
size_t operator()(size_t gid) const
153+
{
154+
return compute_offset(static_cast<py::ssize_t>(gid));
155+
}
156+
157+
private:
158+
int nd;
159+
py::ssize_t starting_offset;
160+
py::ssize_t const *shape_strides;
161+
162+
size_t compute_offset(py::ssize_t gid) const
147163
{
148164
using dpctl::tensor::strides::CIndexer_vector;
149165

150166
CIndexer_vector _ind(nd);
151167
py::ssize_t relative_offset(0);
152168
_ind.get_displacement<const py::ssize_t *, const py::ssize_t *>(
153-
static_cast<py::ssize_t>(gid),
169+
gid,
154170
shape_strides, // shape ptr
155171
shape_strides + nd, // strides ptr
156172
relative_offset);
157173
return starting_offset + relative_offset;
158174
}
175+
};
176+
177+
/* @brief Indexer with shape, strides provided separately */
178+
struct UnpackedStridedIndexer
179+
{
180+
UnpackedStridedIndexer(int _nd,
181+
py::ssize_t _offset,
182+
py::ssize_t const *_shape,
183+
py::ssize_t const *_strides)
184+
: nd(_nd), starting_offset(_offset), shape(_shape), strides(_strides)
185+
{
186+
}
187+
188+
size_t operator()(py::ssize_t gid) const
189+
{
190+
return compute_offset(gid);
191+
}
192+
193+
size_t operator()(size_t gid) const
194+
{
195+
return compute_offset(static_cast<py::ssize_t>(gid));
196+
}
159197

160198
private:
161199
int nd;
162200
py::ssize_t starting_offset;
163-
py::ssize_t const *shape_strides;
201+
py::ssize_t const *shape;
202+
py::ssize_t const *strides;
203+
204+
size_t compute_offset(py::ssize_t gid) const
205+
{
206+
using dpctl::tensor::strides::CIndexer_vector;
207+
208+
CIndexer_vector _ind(nd);
209+
py::ssize_t relative_offset(0);
210+
_ind.get_displacement<const py::ssize_t *, const py::ssize_t *>(
211+
gid,
212+
shape, // shape ptr
213+
strides, // strides ptr
214+
relative_offset);
215+
return starting_offset + relative_offset;
216+
}
164217
};
165218

166219
struct Strided1DIndexer
@@ -206,7 +259,8 @@ struct Strided1DCyclicIndexer
206259
template <typename displacementT> struct TwoOffsets
207260
{
208261
TwoOffsets() : first_offset(0), second_offset(0) {}
209-
TwoOffsets(displacementT first_offset_, displacementT second_offset_)
262+
TwoOffsets(const displacementT &first_offset_,
263+
const displacementT &second_offset_)
210264
: first_offset(first_offset_), second_offset(second_offset_)
211265
{
212266
}
@@ -238,6 +292,22 @@ struct TwoOffsets_StridedIndexer
238292
}
239293

240294
TwoOffsets<py::ssize_t> operator()(py::ssize_t gid) const
295+
{
296+
return compute_offsets(gid);
297+
}
298+
299+
TwoOffsets<py::ssize_t> operator()(size_t gid) const
300+
{
301+
return compute_offsets(static_cast<py::ssize_t>(gid));
302+
}
303+
304+
private:
305+
int nd;
306+
py::ssize_t starting_first_offset;
307+
py::ssize_t starting_second_offset;
308+
py::ssize_t const *shape_strides;
309+
310+
TwoOffsets<py::ssize_t> compute_offsets(py::ssize_t gid) const
241311
{
242312
using dpctl::tensor::strides::CIndexer_vector;
243313

@@ -254,12 +324,6 @@ struct TwoOffsets_StridedIndexer
254324
starting_first_offset + relative_first_offset,
255325
starting_second_offset + relative_second_offset);
256326
}
257-
258-
private:
259-
int nd;
260-
py::ssize_t starting_first_offset;
261-
py::ssize_t starting_second_offset;
262-
py::ssize_t const *shape_strides;
263327
};
264328

265329
struct TwoZeroOffsets_Indexer
@@ -272,12 +336,33 @@ struct TwoZeroOffsets_Indexer
272336
}
273337
};
274338

339+
template <typename FirstIndexerT, typename SecondIndexerT>
340+
struct TwoOffsets_CombinedIndexer
341+
{
342+
private:
343+
FirstIndexerT first_indexer_;
344+
SecondIndexerT second_indexer_;
345+
346+
public:
347+
TwoOffsets_CombinedIndexer(const FirstIndexerT &first_indexer,
348+
const SecondIndexerT &second_indexer)
349+
: first_indexer_(first_indexer), second_indexer_(second_indexer)
350+
{
351+
}
352+
353+
TwoOffsets<py::ssize_t> operator()(py::ssize_t gid) const
354+
{
355+
return TwoOffsets<py::ssize_t>(first_indexer_(gid),
356+
second_indexer_(gid));
357+
}
358+
};
359+
275360
template <typename displacementT> struct ThreeOffsets
276361
{
277362
ThreeOffsets() : first_offset(0), second_offset(0), third_offset(0) {}
278-
ThreeOffsets(displacementT first_offset_,
279-
displacementT second_offset_,
280-
displacementT third_offset_)
363+
ThreeOffsets(const displacementT &first_offset_,
364+
const displacementT &second_offset_,
365+
const displacementT &third_offset_)
281366
: first_offset(first_offset_), second_offset(second_offset_),
282367
third_offset(third_offset_)
283368
{
@@ -317,6 +402,23 @@ struct ThreeOffsets_StridedIndexer
317402
}
318403

319404
ThreeOffsets<py::ssize_t> operator()(py::ssize_t gid) const
405+
{
406+
return compute_offsets(gid);
407+
}
408+
409+
ThreeOffsets<py::ssize_t> operator()(size_t gid) const
410+
{
411+
return compute_offsets(static_cast<py::ssize_t>(gid));
412+
}
413+
414+
private:
415+
int nd;
416+
py::ssize_t starting_first_offset;
417+
py::ssize_t starting_second_offset;
418+
py::ssize_t starting_third_offset;
419+
py::ssize_t const *shape_strides;
420+
421+
ThreeOffsets<py::ssize_t> compute_offsets(py::ssize_t gid) const
320422
{
321423
using dpctl::tensor::strides::CIndexer_vector;
322424

@@ -337,13 +439,6 @@ struct ThreeOffsets_StridedIndexer
337439
starting_second_offset + relative_second_offset,
338440
starting_third_offset + relative_third_offset);
339441
}
340-
341-
private:
342-
int nd;
343-
py::ssize_t starting_first_offset;
344-
py::ssize_t starting_second_offset;
345-
py::ssize_t starting_third_offset;
346-
py::ssize_t const *shape_strides;
347442
};
348443

349444
struct ThreeZeroOffsets_Indexer

0 commit comments

Comments
 (0)