Skip to content

Commit 006a05a

Browse files
Moved functions in strided_iters into dpctl::tensor::strides namespace
1 parent bcd2c98 commit 006a05a

File tree

6 files changed

+42
-16
lines changed

6 files changed

+42
-16
lines changed

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -609,11 +609,11 @@ sycl::event tri_impl(sycl::queue exec_q,
609609
py::ssize_t outer_gid = idx[0] / inner_range;
610610
py::ssize_t inner_gid = idx[0] - inner_range * outer_gid;
611611

612-
py::ssize_t src_inner_offset, dst_inner_offset;
613-
bool to_copy;
612+
py::ssize_t src_inner_offset = 0, dst_inner_offset = 0;
613+
bool to_copy(true);
614614

615615
{
616-
// py::ssize_t inner_gid = idx.get_id(0);
616+
using dpctl::tensor::strides::CIndexer_array;
617617
CIndexer_array<d2, py::ssize_t> indexer_i(
618618
{shape_and_strides[nd_2], shape_and_strides[nd_1]});
619619
indexer_i.set(inner_gid);
@@ -634,7 +634,7 @@ sycl::event tri_impl(sycl::queue exec_q,
634634
py::ssize_t src_offset = 0;
635635
py::ssize_t dst_offset = 0;
636636
{
637-
// py::ssize_t outer_gid = idx.get_id(1);
637+
using dpctl::tensor::strides::CIndexer_vector;
638638
CIndexer_vector<py::ssize_t> outer(nd - d2);
639639
outer.get_displacement(
640640
outer_gid, shape_and_strides, shape_and_strides + src_s,

dpctl/tensor/libtensor/include/utils/offset_utils.hpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ struct StridedIndexer
141141

142142
size_t operator()(size_t gid) const
143143
{
144+
using dpctl::tensor::strides::CIndexer_vector;
145+
144146
CIndexer_vector _ind(nd);
145147
py::ssize_t relative_offset(0);
146148
_ind.get_displacement<const py::ssize_t *, const py::ssize_t *>(
@@ -233,6 +235,8 @@ struct TwoOffsets_StridedIndexer
233235

234236
TwoOffsets<py::ssize_t> operator()(py::ssize_t gid) const
235237
{
238+
using dpctl::tensor::strides::CIndexer_vector;
239+
236240
CIndexer_vector _ind(nd);
237241
py::ssize_t relative_first_offset(0);
238242
py::ssize_t relative_second_offset(0);
@@ -285,7 +289,7 @@ struct NthStrideOffset
285289
}
286290

287291
private:
288-
CIndexer_vector<py::ssize_t> _ind;
292+
dpctl::tensor::strides::CIndexer_vector<py::ssize_t> _ind;
289293

290294
int nd;
291295
py::ssize_t const *offsets;
@@ -302,11 +306,13 @@ template <int nd> struct FixedDimStridedIndexer
302306
}
303307
size_t operator()(size_t gid) const
304308
{
309+
dpctl::tensor::strides::CIndexer_array<nd, py::ssize_t> local_indexer(
310+
std::move(_ind));
311+
local_indexer.set(gid);
312+
auto mi = local_indexer.get();
313+
305314
py::ssize_t relative_offset = 0;
306-
CIndexer_array<nd, py::ssize_t> local_indxr(std::move(_ind));
307315

308-
local_indxr.set(gid);
309-
auto mi = local_indxr.get();
310316
#pragma unroll
311317
for (int i = 0; i < nd; ++i) {
312318
relative_offset += mi[i] * strides[i];
@@ -315,7 +321,7 @@ template <int nd> struct FixedDimStridedIndexer
315321
}
316322

317323
private:
318-
CIndexer_array<nd, py::ssize_t> _ind;
324+
dpctl::tensor::strides::CIndexer_array<nd, py::ssize_t> _ind;
319325

320326
const std::array<py::ssize_t, nd> strides;
321327
py::ssize_t starting_offset;
@@ -336,26 +342,29 @@ template <int nd> struct TwoOffsets_FixedDimStridedIndexer
336342

337343
TwoOffsets<py::ssize_t> operator()(size_t gid) const
338344
{
339-
py::ssize_t relative_offset1 = 0;
340-
py::ssize_t relative_offset2 = 0;
345+
dpctl::tensor::strides::CIndexer_array<nd, py::ssize_t> local_indexer(
346+
std::move(_ind));
347+
local_indexer.set(gid);
348+
auto mi = local_indexer.get();
341349

342-
CIndexer_array<nd, py::ssize_t> local_indxr(std::move(_ind));
343-
local_indxr.set(gid);
344-
auto mi = local_indxr.get();
350+
py::ssize_t relative_offset1 = 0;
345351
#pragma unroll
346352
for (int i = 0; i < nd; ++i) {
347353
relative_offset1 += mi[i] * strides1[i];
348354
}
355+
356+
py::ssize_t relative_offset2 = 0;
349357
#pragma unroll
350358
for (int i = 0; i < nd; ++i) {
351359
relative_offset2 += mi[i] * strides2[i];
352360
}
361+
353362
return TwoOffsets<py::ssize_t>(starting_offset1 + relative_offset1,
354363
starting_offset2 + relative_offset2);
355364
}
356365

357366
private:
358-
CIndexer_array<nd, py::ssize_t> _ind;
367+
dpctl::tensor::strides::CIndexer_array<nd, py::ssize_t> _ind;
359368

360369
const std::array<py::ssize_t, nd> strides1;
361370
const std::array<py::ssize_t, nd> strides2;

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@
3232
#include <tuple>
3333
#include <vector>
3434

35+
namespace dpctl
36+
{
37+
namespace tensor
38+
{
39+
namespace strides
40+
{
41+
3542
/* An N-dimensional array can be stored in a single
3643
* contiguous chunk of memory by contiguously laying
3744
* array elements in lexicographinc order of their
@@ -708,3 +715,7 @@ contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
708715
return std::make_tuple(out_shape, out_strides1, disp1, out_strides2, disp2,
709716
out_strides3, disp3);
710717
}
718+
719+
} // namespace strides
720+
} // namespace tensor
721+
} // namespace dpctl

dpctl/tensor/libtensor/source/simplify_iteration_space.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
#include "simplify_iteration_space.hpp"
2626
#include "dpctl4pybind11.hpp"
27+
#include "utils/strided_iters.hpp"
2728
#include <pybind11/pybind11.h>
2829
#include <vector>
2930

@@ -49,6 +50,7 @@ void simplify_iteration_space_1(int &nd,
4950
std::vector<py::ssize_t> &simplified_strides,
5051
py::ssize_t &offset)
5152
{
53+
using dpctl::tensor::strides::simplify_iteration_stride;
5254
if (nd > 1) {
5355
// Simplify iteration space to reduce dimensionality
5456
// and improve access pattern
@@ -135,6 +137,7 @@ void simplify_iteration_space(int &nd,
135137
py::ssize_t &src_offset,
136138
py::ssize_t &dst_offset)
137139
{
140+
using dpctl::tensor::strides::simplify_iteration_two_strides;
138141
if (nd > 1) {
139142
// Simplify iteration space to reduce dimensionality
140143
// and improve access pattern
@@ -280,6 +283,7 @@ void simplify_iteration_space_3(
280283
py::ssize_t &src2_offset,
281284
py::ssize_t &dst_offset)
282285
{
286+
using dpctl::tensor::strides::simplify_iteration_three_strides;
283287
if (nd > 1) {
284288
// Simplify iteration space to reduce dimensionality
285289
// and improve access pattern

dpctl/tensor/libtensor/source/simplify_iteration_space.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
//===--------------------------------------------------------------------===//
2424

2525
#pragma once
26-
#include "utils/strided_iters.hpp"
2726
#include <pybind11/pybind11.h>
2827
#include <vector>
2928

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ PYBIND11_MODULE(_tensor_impl, m)
128128
init_dispatch_tables();
129129
init_dispatch_vectors();
130130

131+
using dpctl::tensor::strides::contract_iter;
131132
m.def(
132133
"_contract_iter", &contract_iter<py::ssize_t, py::value_error>,
133134
"Simplifies iteration of array of given shape & stride. Returns "
@@ -143,6 +144,7 @@ PYBIND11_MODULE(_tensor_impl, m)
143144
py::arg("src"), py::arg("dst"), py::arg("sycl_queue"),
144145
py::arg("depends") = py::list());
145146

147+
using dpctl::tensor::strides::contract_iter2;
146148
m.def(
147149
"_contract_iter2", &contract_iter2<py::ssize_t, py::value_error>,
148150
"Simplifies iteration over elements of pair of arrays of given shape "
@@ -152,6 +154,7 @@ PYBIND11_MODULE(_tensor_impl, m)
152154
"as the original "
153155
"iterator, possibly in a different order.");
154156

157+
using dpctl::tensor::strides::contract_iter3;
155158
m.def(
156159
"_contract_iter3", &contract_iter3<py::ssize_t, py::value_error>,
157160
"Simplifies iteration over elements of 3-tuple of arrays of given "

0 commit comments

Comments
 (0)