@@ -643,7 +643,6 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
643
643
input array.
644
644
test_elements (Union[usm_ndarray, bool, int, float, complex]):
645
645
elements against which to test each value of `x`.
646
- Default: `None`.
647
646
assume_unique (Optional[bool]):
648
647
if `True`, the input arrays are both assumed to be unique, which
649
648
currently has no effect.
@@ -681,20 +680,25 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
681
680
dpctl .utils .validate_usm_type (res_usm_type , allow_none = False )
682
681
sycl_dev = exec_q .sycl_device
683
682
683
+ if isinstance (test_elements , dpt .usm_ndarray ) and test_elements .size == 0 :
684
+ if invert :
685
+ return dpt .ones_like (x , dtype = dpt .bool , usm_type = res_usm_type )
686
+ else :
687
+ return dpt .zeros_like (x , dtype = dpt .bool , usm_type = res_usm_type )
688
+
684
689
x_dt = x .dtype
685
690
test_dt = _get_dtype (test_elements , sycl_dev )
686
691
if not _validate_dtype (test_dt ):
687
692
raise ValueError ("`test_elements` has unsupported dtype" )
688
693
689
- dt = dpt .result_type (
690
- * _resolve_weak_types_all_py_ints (x_dt , test_dt , sycl_dev )
691
- )
692
-
693
694
_manager = du .SequentialOrderManager [exec_q ]
695
+ dep_evs = _manager .submitted_events
696
+
697
+ dt1 , dt2 = _resolve_weak_types_all_py_ints (x_dt , test_dt , sycl_dev )
698
+ dt = dpt .result_type (dt1 , dt2 )
694
699
695
700
if x_dt != dt :
696
701
x_buf = _empty_like_orderK (x , dt )
697
- dep_evs = _manager .submitted_events
698
702
ht_ev , ev = _copy_usm_ndarray_into_usm_ndarray (
699
703
src = x , dst = x_buf , sycl_queue = exec_q , depends = dep_evs
700
704
)
@@ -703,11 +707,12 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
703
707
x_buf = x
704
708
705
709
if not isinstance (test_elements , dpt .usm_ndarray ):
706
- test_buf = dpt .asarray (test_elements , dtype = dt , sycl_queue = exec_q )
710
+ test_buf = dpt .asarray (
711
+ test_elements , dtype = dt , usm_type = res_usm_type , sycl_queue = exec_q
712
+ )
707
713
elif test_dt != dt :
708
714
# copy into C-contiguous memory, because the array will be flattened
709
- test_buf = dpt .empty_like (test_elements , dt , order = "C" )
710
- dep_evs = _manager .submitted_events
715
+ test_buf = dpt .empty_like (test_elements , dtype = dt , order = "C" )
711
716
ht_ev , ev = _copy_usm_ndarray_into_usm_ndarray (
712
717
src = test_elements , dst = test_buf , sycl_queue = exec_q , depends = dep_evs
713
718
)
@@ -718,7 +723,9 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
718
723
test_buf = dpt .reshape (test_buf , - 1 )
719
724
test_buf = dpt .sort (test_buf )
720
725
721
- dst = _empty_like_orderK (x_buf , dpt .bool , usm_type = res_usm_type )
726
+ dst = dpt .empty_like (
727
+ x_buf , dtype = dpt .bool , usm_type = res_usm_type , order = "C"
728
+ )
722
729
723
730
dep_evs = _manager .submitted_events
724
731
ht_ev , s_ev = _isin (
0 commit comments