Skip to content

Commit 98c6651

Browse files
committed
Implements contract_iter4
- Matches other contract_iter functions
1 parent c5d4959 commit 98c6651

File tree

3 files changed

+48
-1
lines changed

3 files changed

+48
-1
lines changed

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,42 @@ int simplify_iteration_four_strides(const int nd,
872872
return nd_;
873873
}
874874

875+
template <typename T, class Error, typename vecT = std::vector<T>>
876+
std::tuple<vecT, vecT, T, vecT, T, vecT, T, vecT, T>
877+
contract_iter4(vecT shape,
878+
vecT strides1,
879+
vecT strides2,
880+
vecT strides3,
881+
vecT strides4)
882+
{
883+
const size_t dim = shape.size();
884+
if (dim != strides1.size() || dim != strides2.size() ||
885+
dim != strides3.size() || dim != strides4.size())
886+
{
887+
throw Error("Shape and strides must be of equal size.");
888+
}
889+
vecT out_shape = shape;
890+
vecT out_strides1 = strides1;
891+
vecT out_strides2 = strides2;
892+
vecT out_strides3 = strides3;
893+
vecT out_strides4 = strides4;
894+
T disp1(0);
895+
T disp2(0);
896+
T disp3(0);
897+
T disp4(0);
898+
899+
int nd = simplify_iteration_four_strides(
900+
dim, out_shape.data(), out_strides1.data(), out_strides2.data(),
901+
out_strides3.data(), out_strides4.data(), disp1, disp2, disp3, disp4);
902+
out_shape.resize(nd);
903+
out_strides1.resize(nd);
904+
out_strides2.resize(nd);
905+
out_strides3.resize(nd);
906+
out_strides4.resize(nd);
907+
return std::make_tuple(out_shape, out_strides1, disp1, out_strides2, disp2,
908+
out_strides3, disp3, out_strides4, disp4);
909+
}
910+
875911
} // namespace strides
876912
} // namespace tensor
877913
} // namespace dpctl

dpctl/tensor/libtensor/source/simplify_iteration_space.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ void simplify_iteration_space_4(
674674
assert(simplified_shape.size() == static_cast<size_t>(nd));
675675
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
676676
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
677-
assert(simplified_src3_strides.size() == static_cast < size_t(nd));
677+
assert(simplified_src3_strides.size() == static_cast<size_t>(nd));
678678
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
679679
}
680680
shape = const_cast<const py::ssize_t *>(simplified_shape.data());

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,17 @@ PYBIND11_MODULE(_tensor_impl, m)
171171
"as the original "
172172
"iterator, possibly in a different order.");
173173

174+
using dpctl::tensor::strides::contract_iter4;
175+
m.def(
176+
"_contract_iter4", &contract_iter4<py::ssize_t, py::value_error>,
177+
"Simplifies iteration over elements of 4-tuple of arrays of given "
178+
"shape "
179+
"with strides stride1, stride2, stride3, and stride4. Returns "
180+
"a 9-tuple: shape, stride and offset for the new iterator of possible "
181+
"smaller dimension for each array, which traverses the same elements "
182+
"as the original "
183+
"iterator, possibly in a different order.");
184+
174185
m.def("_copy_usm_ndarray_for_reshape", &copy_usm_ndarray_for_reshape,
175186
"Copies from usm_ndarray `src` into usm_ndarray `dst` with the same "
176187
"number of elements using underlying 'C'-contiguous order for flat "

0 commit comments

Comments
 (0)