Skip to content

Commit 03a0676

Browse files
committed
Update per review comments
1 parent 36ee0f2 commit 03a0676

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

dpctl/tensor/_set_functions.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,6 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
643643
input array.
644644
test_elements (Union[usm_ndarray, bool, int, float, complex]):
645645
elements against which to test each value of `x`.
646-
Default: `None`.
647646
assume_unique (Optional[bool]):
648647
if `True`, the input arrays are both assumed to be unique, which
649648
currently has no effect.
@@ -681,20 +680,25 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
681680
dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
682681
sycl_dev = exec_q.sycl_device
683682

683+
if isinstance(test_elements, dpt.usm_ndarray) and test_elements.size == 0:
684+
if invert:
685+
return dpt.ones_like(x, dtype=dpt.bool, usm_type=res_usm_type)
686+
else:
687+
return dpt.zeros_like(x, dtype=dpt.bool, usm_type=res_usm_type)
688+
684689
x_dt = x.dtype
685690
test_dt = _get_dtype(test_elements, sycl_dev)
686691
if not _validate_dtype(test_dt):
687692
raise ValueError("`test_elements` has unsupported dtype")
688693

689-
dt = dpt.result_type(
690-
*_resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev)
691-
)
692-
693694
_manager = du.SequentialOrderManager[exec_q]
695+
dep_evs = _manager.submitted_events
696+
697+
dt1, dt2 = _resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev)
698+
dt = dpt.result_type(dt1, dt2)
694699

695700
if x_dt != dt:
696701
x_buf = _empty_like_orderK(x, dt)
697-
dep_evs = _manager.submitted_events
698702
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
699703
src=x, dst=x_buf, sycl_queue=exec_q, depends=dep_evs
700704
)
@@ -703,11 +707,12 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
703707
x_buf = x
704708

705709
if not isinstance(test_elements, dpt.usm_ndarray):
706-
test_buf = dpt.asarray(test_elements, dtype=dt, sycl_queue=exec_q)
710+
test_buf = dpt.asarray(
711+
test_elements, dtype=dt, usm_type=res_usm_type, sycl_queue=exec_q
712+
)
707713
elif test_dt != dt:
708714
# copy into C-contiguous memory, because the array will be flattened
709-
test_buf = dpt.empty_like(test_elements, dt, order="C")
710-
dep_evs = _manager.submitted_events
715+
test_buf = dpt.empty_like(test_elements, dtype=dt, order="C")
711716
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
712717
src=test_elements, dst=test_buf, sycl_queue=exec_q, depends=dep_evs
713718
)
@@ -718,7 +723,9 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
718723
test_buf = dpt.reshape(test_buf, -1)
719724
test_buf = dpt.sort(test_buf)
720725

721-
dst = _empty_like_orderK(x_buf, dpt.bool, usm_type=res_usm_type)
726+
dst = dpt.empty_like(
727+
x_buf, dtype=dpt.bool, usm_type=res_usm_type, order="C"
728+
)
722729

723730
dep_evs = _manager.submitted_events
724731
ht_ev, s_ev = _isin(

0 commit comments

Comments
 (0)