Skip to content

Commit 0fe5ac7

Browse files
Added a test based on invariance of where wrt slicing
1 parent 98c6651 commit 0fe5ac7

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_where_result_types(dt1, dt2, fp16, fp64):
115115

116116

117117
@pytest.mark.parametrize("dt", _all_dtypes)
118-
def test_where_all_dtypes(dt):
118+
def test_where_mask_dtypes(dt):
119119
q = get_queue_or_skip()
120120
skip_if_dtype_not_supported(dt, q)
121121

@@ -311,6 +311,36 @@ def test_where_strided():
311311
assert_array_equal(dpt.asnumpy(res), expected)
312312

313313

314+
def test_where_invariants():
315+
get_queue_or_skip()
316+
317+
test_sh = (
318+
6,
319+
8,
320+
)
321+
mask = dpt.asarray(np.random.choice([True, False], size=test_sh))
322+
p = dpt.ones(test_sh, dtype=dpt.int16)
323+
m = dpt.full(test_sh, -1, dtype=dpt.int16)
324+
inds_list = [
325+
(
326+
np.s_[:3],
327+
np.s_[::2],
328+
),
329+
(
330+
np.s_[::2],
331+
np.s_[::2],
332+
),
333+
(
334+
np.s_[::-1],
335+
np.s_[:],
336+
),
337+
]
338+
for ind in inds_list:
339+
r1 = dpt.where(mask, p, m)[ind]
340+
r2 = dpt.where(mask[ind], p[ind], m[ind])
341+
assert (dpt.asnumpy(r1) == dpt.asnumpy(r2)).all()
342+
343+
314344
def test_where_arg_validation():
315345
get_queue_or_skip()
316346

0 commit comments

Comments
 (0)