Skip to content

Commit 456f46f

Browse files
Merge pull request #1171 from IntelPython/where-fix-gh-1170
Fixes incorrect output in dpctl.tensor.where strided implementation
2 parents 02d1f94 + 0fe5ac7 commit 456f46f

File tree

8 files changed

+662
-18
lines changed

8 files changed

+662
-18
lines changed

dpctl/tensor/libtensor/include/kernels/where.hpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,9 @@ class WhereStridedFunctor
210210
bool check =
211211
convert_impl<bool, condT>(cond_p[offsets.get_first_offset()]);
212212

213-
dst_p[gid] = check ? x1_p[offsets.get_second_offset()]
214-
: x2_p[offsets.get_third_offset()];
213+
dst_p[offsets.get_fourth_offset()] =
214+
check ? x1_p[offsets.get_second_offset()]
215+
: x2_p[offsets.get_third_offset()];
215216
}
216217
};
217218

@@ -227,6 +228,7 @@ typedef sycl::event (*where_strided_impl_fn_ptr_t)(
227228
py::ssize_t,
228229
py::ssize_t,
229230
py::ssize_t,
231+
py::ssize_t,
230232
const std::vector<sycl::event> &);
231233

232234
template <typename T, typename condT>
@@ -241,6 +243,7 @@ sycl::event where_strided_impl(sycl::queue q,
241243
py::ssize_t x1_offset,
242244
py::ssize_t x2_offset,
243245
py::ssize_t cond_offset,
246+
py::ssize_t dst_offset,
244247
const std::vector<sycl::event> &depends)
245248
{
246249
const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
@@ -251,13 +254,13 @@ sycl::event where_strided_impl(sycl::queue q,
251254
sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
252255
cgh.depends_on(depends);
253256

254-
ThreeOffsets_StridedIndexer indexer{nd, cond_offset, x1_offset,
255-
x2_offset, shape_strides};
257+
FourOffsets_StridedIndexer indexer{
258+
nd, cond_offset, x1_offset, x2_offset, dst_offset, shape_strides};
256259

257260
cgh.parallel_for<
258-
where_strided_kernel<T, condT, ThreeOffsets_StridedIndexer>>(
261+
where_strided_kernel<T, condT, FourOffsets_StridedIndexer>>(
259262
sycl::range<1>(nelems),
260-
WhereStridedFunctor<T, condT, ThreeOffsets_StridedIndexer>(
263+
WhereStridedFunctor<T, condT, FourOffsets_StridedIndexer>(
261264
cond_tp, x1_tp, x2_tp, dst_tp, indexer));
262265
});
263266

dpctl/tensor/libtensor/include/utils/offset_utils.hpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,115 @@ struct ThreeZeroOffsets_Indexer
451451
}
452452
};
453453

454+
/*
  Value type bundling four displacements, one for each of the four arrays
  visited together in a single element-wise traversal.

  The offsets are fixed at construction and are only exposed through the
  read-only accessors below.
*/
template <typename displacementT> struct FourOffsets
{
    // All four offsets are zero; the values come from the default member
    // initializers at the bottom of the struct, so the zero default is
    // spelled out in exactly one place.
    FourOffsets() = default;

    // Stores the given offsets verbatim, in first..fourth order.
    FourOffsets(const displacementT &first_offset_,
                const displacementT &second_offset_,
                const displacementT &third_offset_,
                const displacementT &fourth_offset_)
        : first_offset(first_offset_), second_offset(second_offset_),
          third_offset(third_offset_), fourth_offset(fourth_offset_)
    {
    }

    // Accessors return the stored offsets by value.
    displacementT get_first_offset() const
    {
        return first_offset;
    }
    displacementT get_second_offset() const
    {
        return second_offset;
    }
    displacementT get_third_offset() const
    {
        return third_offset;
    }
    displacementT get_fourth_offset() const
    {
        return fourth_offset;
    }

private:
    displacementT first_offset = 0;
    displacementT second_offset = 0;
    displacementT third_offset = 0;
    displacementT fourth_offset = 0;
};
492+
493+
struct FourOffsets_StridedIndexer
494+
{
495+
FourOffsets_StridedIndexer(int common_nd,
496+
py::ssize_t first_offset_,
497+
py::ssize_t second_offset_,
498+
py::ssize_t third_offset_,
499+
py::ssize_t fourth_offset_,
500+
py::ssize_t const *_packed_shape_strides)
501+
: nd(common_nd), starting_first_offset(first_offset_),
502+
starting_second_offset(second_offset_),
503+
starting_third_offset(third_offset_),
504+
starting_fourth_offset(fourth_offset_),
505+
shape_strides(_packed_shape_strides)
506+
{
507+
}
508+
509+
FourOffsets<py::ssize_t> operator()(py::ssize_t gid) const
510+
{
511+
return compute_offsets(gid);
512+
}
513+
514+
FourOffsets<py::ssize_t> operator()(size_t gid) const
515+
{
516+
return compute_offsets(static_cast<py::ssize_t>(gid));
517+
}
518+
519+
private:
520+
int nd;
521+
py::ssize_t starting_first_offset;
522+
py::ssize_t starting_second_offset;
523+
py::ssize_t starting_third_offset;
524+
py::ssize_t starting_fourth_offset;
525+
py::ssize_t const *shape_strides;
526+
527+
FourOffsets<py::ssize_t> compute_offsets(py::ssize_t gid) const
528+
{
529+
using dpctl::tensor::strides::CIndexer_vector;
530+
531+
CIndexer_vector _ind(nd);
532+
py::ssize_t relative_first_offset(0);
533+
py::ssize_t relative_second_offset(0);
534+
py::ssize_t relative_third_offset(0);
535+
py::ssize_t relative_fourth_offset(0);
536+
_ind.get_displacement<const py::ssize_t *, const py::ssize_t *>(
537+
gid,
538+
shape_strides, // shape ptr
539+
shape_strides + nd, // strides ptr
540+
shape_strides + 2 * nd, // strides ptr
541+
shape_strides + 3 * nd, // strides ptr
542+
shape_strides + 4 * nd, // strides ptr
543+
relative_first_offset, relative_second_offset,
544+
relative_third_offset, relative_fourth_offset);
545+
return FourOffsets<py::ssize_t>(
546+
starting_first_offset + relative_first_offset,
547+
starting_second_offset + relative_second_offset,
548+
starting_third_offset + relative_third_offset,
549+
starting_fourth_offset + relative_fourth_offset);
550+
}
551+
};
552+
553+
struct FourZeroOffsets_Indexer
554+
{
555+
FourZeroOffsets_Indexer() {}
556+
557+
FourOffsets<py::ssize_t> operator()(py::ssize_t) const
558+
{
559+
return FourOffsets<py::ssize_t>();
560+
}
561+
};
562+
454563
struct NthStrideOffset
455564
{
456565
NthStrideOffset(int common_nd,

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 193 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ int simplify_iteration_three_strides(const int nd,
623623
auto str3_p = strides3[p];
624624
shape_w.push_back(sh_p);
625625
if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
626-
std::min(std::min(str1_p, str2_p), str3_p) < 0)
626+
std::min({str1_p, str2_p, str3_p}) < 0)
627627
{
628628
disp1 += str1_p * (sh_p - 1);
629629
str1_p = -str1_p;
@@ -716,6 +716,198 @@ contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
716716
out_strides3, disp3);
717717
}
718718

719+
/*
   For purposes of iterating over tuples of elements of four arrays
   with `shape` and strides `strides1`, `strides2`, `strides3`,
   `strides4` given as pointers, `simplify_iteration_four_strides(nd,
   shape_ptr, strides1_ptr, strides2_ptr, strides3_ptr, strides4_ptr,
   disp1, disp2, disp3, disp4)` may modify memory and returns the new
   length of these arrays.

   The new shape and new strides, as well as the offsets
   `(new_shape, new_strides1, disp1, new_strides2, disp2, new_strides3,
   disp3, new_strides4, disp4)` are such that iterating over them will
   traverse the same set of tuples of elements, possibly in a different
   order.

   Steps performed when `nd >= 2`:
     1. Axes are stably ordered by decreasing absolute stride, comparing
        `strides1` first and breaking ties with `strides2`, `strides3`,
        `strides4` and finally the extent.
     2. An axis along which all four strides are non-positive and at least
        one is negative is reversed: the strides are negated and each
        displacement is moved to the other end of that axis.
     3. Provided no negative stride is left, neighbouring axes `i`, `i + 1`
        are merged while, for every array,
        `strides[i] == shape[i + 1] * strides[i + 1]`.

   All four displacements are reset to zero on entry, so the caller does
   not need to initialize them. For `nd < 2` nothing else is modified and
   `nd` is returned. Only the first `nd_` (returned value) entries of
   `shape` and of each strides array are rewritten.
*/
template <class ShapeTy, class StridesTy>
int simplify_iteration_four_strides(const int nd,
                                    ShapeTy *shape,
                                    StridesTy *strides1,
                                    StridesTy *strides2,
                                    StridesTy *strides3,
                                    StridesTy *strides4,
                                    StridesTy &disp1,
                                    StridesTy &disp2,
                                    StridesTy &disp3,
                                    StridesTy &disp4)
{
    disp1 = std::ptrdiff_t(0);
    disp2 = std::ptrdiff_t(0);
    disp3 = std::ptrdiff_t(0);
    disp4 = std::ptrdiff_t(0);
    if (nd < 2)
        return nd;

    // Permutation of axes, ordered by decreasing absolute stride.
    std::vector<int> pos(nd);
    std::iota(pos.begin(), pos.end(), 0);

    std::stable_sort(
        pos.begin(), pos.end(),
        [&strides1, &strides2, &strides3, &strides4, &shape](int i1, int i2) {
            const auto abs_val = [](StridesTy s) { return (s < 0) ? -s : s; };

            const auto a1 = abs_val(strides1[i1]);
            const auto b1 = abs_val(strides1[i2]);
            if (a1 != b1)
                return a1 > b1;

            const auto a2 = abs_val(strides2[i1]);
            const auto b2 = abs_val(strides2[i2]);
            if (a2 != b2)
                return a2 > b2;

            const auto a3 = abs_val(strides3[i1]);
            const auto b3 = abs_val(strides3[i2]);
            if (a3 != b3)
                return a3 > b3;

            const auto a4 = abs_val(strides4[i1]);
            const auto b4 = abs_val(strides4[i2]);
            if (a4 != b4)
                return a4 > b4;

            return shape[i1] > shape[i2];
        });

    // Working copies in permuted order; the caller's arrays are only
    // overwritten at the very end.
    std::vector<ShapeTy> shape_w;
    std::vector<StridesTy> strides1_w;
    std::vector<StridesTy> strides2_w;
    std::vector<StridesTy> strides3_w;
    std::vector<StridesTy> strides4_w;
    shape_w.reserve(nd);
    strides1_w.reserve(nd);
    strides2_w.reserve(nd);
    strides3_w.reserve(nd);
    strides4_w.reserve(nd);

    bool contractable = true;
    for (int i = 0; i < nd; ++i) {
        const auto p = pos[i];
        const auto sh_p = shape[p];
        auto str1_p = strides1[p];
        auto str2_p = strides2[p];
        auto str3_p = strides3[p];
        auto str4_p = strides4[p];
        shape_w.push_back(sh_p);

        // Reverse the axis when no array walks it forwards and at least one
        // walks it backwards.
        if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 && str4_p <= 0 &&
            std::min({str1_p, str2_p, str3_p, str4_p}) < 0)
        {
            disp1 += str1_p * (sh_p - 1);
            str1_p = -str1_p;
            disp2 += str2_p * (sh_p - 1);
            str2_p = -str2_p;
            disp3 += str3_p * (sh_p - 1);
            str3_p = -str3_p;
            disp4 += str4_p * (sh_p - 1);
            str4_p = -str4_p;
        }
        // Mixed-sign axes cannot be reversed consistently; contraction is
        // then skipped altogether.
        if (str1_p < 0 || str2_p < 0 || str3_p < 0 || str4_p < 0) {
            contractable = false;
        }
        strides1_w.push_back(str1_p);
        strides2_w.push_back(str2_p);
        strides3_w.push_back(str3_p);
        strides4_w.push_back(str4_p);
    }

    int nd_ = nd;
    while (contractable) {
        bool changed = false;
        for (int i = 0; i + 1 < nd_; ++i) {
            const StridesTy str1 = strides1_w[i + 1];
            const StridesTy str2 = strides2_w[i + 1];
            const StridesTy str3 = strides3_w[i + 1];
            const StridesTy str4 = strides4_w[i + 1];
            const StridesTy jump1 =
                strides1_w[i] - (shape_w[i + 1] - 1) * str1;
            const StridesTy jump2 =
                strides2_w[i] - (shape_w[i + 1] - 1) * str2;
            const StridesTy jump3 =
                strides3_w[i] - (shape_w[i + 1] - 1) * str3;
            const StridesTy jump4 =
                strides4_w[i] - (shape_w[i + 1] - 1) * str4;

            if (jump1 == str1 && jump2 == str2 && jump3 == str3 &&
                jump4 == str4) {
                changed = true;
                // Axis i absorbs axis i + 1: extents multiply and the inner
                // stride survives. Erasing keeps every access in bounds.
                shape_w[i] *= shape_w[i + 1];
                shape_w.erase(shape_w.begin() + (i + 1));
                strides1_w.erase(strides1_w.begin() + i);
                strides2_w.erase(strides2_w.begin() + i);
                strides3_w.erase(strides3_w.begin() + i);
                strides4_w.erase(strides4_w.begin() + i);
                --nd_;
                break; // rescan from the first axis
            }
        }
        if (!changed)
            break;
    }

    // Each working vector holds exactly nd_ entries at this point.
    std::copy(shape_w.begin(), shape_w.end(), shape);
    std::copy(strides1_w.begin(), strides1_w.end(), strides1);
    std::copy(strides2_w.begin(), strides2_w.end(), strides2);
    std::copy(strides3_w.begin(), strides3_w.end(), strides3);
    std::copy(strides4_w.begin(), strides4_w.end(), strides4);

    return nd_;
}
874+
875+
/*
  Vector-based convenience wrapper around `simplify_iteration_four_strides`.

  Takes a shape and four stride vectors of the same length and returns the
  tuple `(shape, strides1, disp1, strides2, disp2, strides3, disp3,
  strides4, disp4)` produced by the simplification, with every vector
  resized to the simplified number of dimensions.

  Throws `Error("Shape and strides must be of equal size.")` when any
  stride vector differs in length from `shape`.
*/
template <typename T, class Error, typename vecT = std::vector<T>>
std::tuple<vecT, vecT, T, vecT, T, vecT, T, vecT, T>
contract_iter4(vecT shape,
               vecT strides1,
               vecT strides2,
               vecT strides3,
               vecT strides4)
{
    const size_t dim = shape.size();
    if (dim != strides1.size() || dim != strides2.size() ||
        dim != strides3.size() || dim != strides4.size())
    {
        throw Error("Shape and strides must be of equal size.");
    }

    T disp1(0);
    T disp2(0);
    T disp3(0);
    T disp4(0);

    // The arguments were taken by value, so they are already private copies
    // and can be simplified in place without duplicating them again.
    const int nd = simplify_iteration_four_strides(
        static_cast<int>(dim), shape.data(), strides1.data(), strides2.data(),
        strides3.data(), strides4.data(), disp1, disp2, disp3, disp4);

    // Drop the trailing entries that the simplification no longer uses.
    shape.resize(nd);
    strides1.resize(nd);
    strides2.resize(nd);
    strides3.resize(nd);
    strides4.resize(nd);

    return std::make_tuple(shape, strides1, disp1, strides2, disp2, strides3,
                           disp3, strides4, disp4);
}
910+
719911
} // namespace strides
720912
} // namespace tensor
721913
} // namespace dpctl

0 commit comments

Comments
 (0)