Skip to content

Commit 7f16fae

Browse files
committed
Add test for isin with Python scalar args
also use queue instead of device in _empty_like_orderK call, preventing compute follows data violation
1 parent 1a67133 commit 7f16fae

File tree

2 files changed

+45
-20
lines changed

2 files changed

+45
-20
lines changed

dpctl/tensor/_set_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ def isin(
734734
dep_evs = _manager.submitted_events
735735

736736
if x_dt != dt:
737-
x_buf = _empty_like_orderK(x_arr, dt, res_usm_type, sycl_dev)
737+
x_buf = _empty_like_orderK(x_arr, dt, res_usm_type, exec_q)
738738
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
739739
src=x_arr, dst=x_buf, sycl_queue=exec_q, depends=dep_evs
740740
)

dpctl/tests/test_tensor_isin.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,35 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import ctypes
18+
19+
import numpy as np
1720
import pytest
1821

1922
import dpctl.tensor as dpt
2023
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2124
from dpctl.utils import ExecutionPlacementError
2225

23-
24-
@pytest.mark.parametrize(
25-
"dtype",
26-
[
27-
"i1",
28-
"u1",
29-
"i2",
30-
"u2",
31-
"i4",
32-
"u4",
33-
"i8",
34-
"u8",
35-
"f2",
36-
"f4",
37-
"f8",
38-
"c8",
39-
"c16",
40-
],
41-
)
26+
_numeric_dtypes = [
27+
"i1",
28+
"u1",
29+
"i2",
30+
"u2",
31+
"i4",
32+
"u4",
33+
"i8",
34+
"u8",
35+
"f2",
36+
"f4",
37+
"f8",
38+
"c8",
39+
"c16",
40+
]
41+
42+
_all_dtypes = ["?"] + _numeric_dtypes
43+
44+
45+
@pytest.mark.parametrize("dtype", _numeric_dtypes)
4246
def test_isin_basic(dtype):
4347
q = get_queue_or_skip()
4448
skip_if_dtype_not_supported(dtype, q)
@@ -192,3 +196,24 @@ def test_isin_special_floating_point_vals():
192196
test = dpt.asarray(0.0, dtype="f4")
193197
assert dpt.isin(x, test)
194198
assert dpt.isin(test, x)
199+
200+
201+
@pytest.mark.parametrize("dt", _all_dtypes)
202+
def test_isin_py_scalars(dt):
203+
q = get_queue_or_skip()
204+
skip_if_dtype_not_supported(dt, q)
205+
206+
x = dpt.zeros((10, 10), dtype=dt, sycl_queue=q)
207+
py_zeros = (
208+
bool(0),
209+
int(0),
210+
float(0),
211+
complex(0),
212+
np.float32(0),
213+
ctypes.c_int(0),
214+
)
215+
for sc in py_zeros:
216+
r1 = dpt.isin(x, sc)
217+
assert isinstance(r1, dpt.usm_ndarray)
218+
r2 = dpt.isin(sc, x)
219+
assert isinstance(r2, dpt.usm_ndarray)

0 commit comments

Comments
 (0)