Skip to content

Commit 9246e14

Browse files
Introduced dpctl::tensor::overlap::SameLogicalTensor
The call operator of this struct verifies whether two USM ND-arrays logically address the same memory elements. In the case when data-parallel read from and write to arrays that locally address the same memory elements there is no race condition and no additional copying is needed.
1 parent 6561893 commit 9246e14

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

dpctl/tensor/libtensor/include/utils/memory_overlap.hpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,53 @@ struct MemoryOverlap
100100
}
101101
};
102102

103+
struct SameLogicalTensors
104+
{
105+
bool operator()(dpctl::tensor::usm_ndarray ar1,
106+
dpctl::tensor::usm_ndarray ar2) const
107+
{
108+
// Same ndim
109+
int nd1 = ar1.get_ndim();
110+
if (nd1 != ar2.get_ndim())
111+
return false;
112+
113+
// Same dtype
114+
int tn1 = ar1.get_typenum();
115+
if (tn1 != ar2.get_typenum())
116+
return false;
117+
118+
// Same pointer
119+
const char *ar1_data = ar1.get_data();
120+
const char *ar2_data = ar2.get_data();
121+
122+
if (ar1_data != ar2_data)
123+
return false;
124+
125+
// Same shape and strides
126+
const py::ssize_t *ar1_shape = ar1.get_shape_raw();
127+
const py::ssize_t *ar2_shape = ar2.get_shape_raw();
128+
129+
if (!std::equal(ar1_shape, ar1_shape + nd1, ar2_shape))
130+
return false;
131+
132+
// Same shape and strides
133+
auto const &ar1_strides = ar1.get_strides_vector();
134+
auto const &ar2_strides = ar2.get_strides_vector();
135+
136+
auto ar1_beg_it = std::begin(ar1_strides);
137+
auto ar1_end_it = std::end(ar1_strides);
138+
139+
auto ar2_beg_it = std::begin(ar2_strides);
140+
141+
if (!std::equal(ar1_beg_it, ar1_end_it, ar2_beg_it))
142+
return false;
143+
144+
// all checks passed: arrays are logical views
145+
// into the same memory
146+
return true;
147+
}
148+
};
149+
103150
} // namespace overlap
104151
} // namespace tensor
105152
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ py_unary_ufunc(dpctl::tensor::usm_ndarray src,
128128

129129
// check memory overlap
130130
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
131-
if (overlap(src, dst)) {
131+
auto const &same_logical_tensors =
132+
dpctl::tensor::overlap::SameLogicalTensors();
133+
if (overlap(src, dst) && !same_logical_tensors(src, dst)) {
132134
throw py::value_error("Arrays index overlapping segments of memory");
133135
}
134136

0 commit comments

Comments
 (0)