Skip to content

Commit 277747c

Browse files
committed
Add test for combinations of dtypes as inputs to isin
1 parent 7f16fae commit 277747c

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

dpctl/tests/test_tensor_isin.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,35 @@ def test_isin_strided_bool():
142142
assert r2.shape == x_s.shape
143143

144144

145+
@pytest.mark.parametrize("dt1", _numeric_dtypes)
146+
@pytest.mark.parametrize("dt2", _numeric_dtypes)
147+
def test_isin_dtype_matrix(dt1, dt2):
148+
q = get_queue_or_skip()
149+
skip_if_dtype_not_supported(dt1, q)
150+
skip_if_dtype_not_supported(dt2, q)
151+
152+
sz = 10
153+
x = dpt.asarray([0, 1, 11], dtype=dt1, sycl_queue=q)
154+
test1 = dpt.arange(sz, dtype=dt2, sycl_queue=q)
155+
156+
r1 = dpt.isin(x, test1)
157+
assert isinstance(r1, dpt.usm_ndarray)
158+
assert r1.dtype == dpt.bool
159+
assert r1.shape == x.shape
160+
assert not r1[-1]
161+
assert dpt.all(r1[0:-1])
162+
assert r1.sycl_queue == x.sycl_queue
163+
164+
test2 = dpt.tile(dpt.asarray([[0, 1]], dtype=dt2, sycl_queue=q).mT, 2)
165+
r2 = dpt.isin(x, test2)
166+
assert isinstance(r2, dpt.usm_ndarray)
167+
assert r2.dtype == dpt.bool
168+
assert r2.shape == x.shape
169+
assert not r2[-1]
170+
assert dpt.all(r1[0:-1])
171+
assert r2.sycl_queue == x.sycl_queue
172+
173+
145174
def test_isin_empty_inputs():
146175
get_queue_or_skip()
147176

0 commit comments

Comments
 (0)