21
21
import dpctl .utils as du
22
22
23
23
from ._copy_utils import _empty_like_orderK
24
- from ._scalar_utils import _get_dtype , _get_queue_usm_type , _validate_dtype
24
+ from ._scalar_utils import (
25
+ _get_dtype ,
26
+ _get_queue_usm_type ,
27
+ _get_shape ,
28
+ _validate_dtype ,
29
+ )
25
30
from ._tensor_elementwise_impl import _not_equal , _subtract
26
31
from ._tensor_impl import (
27
32
_copy_usm_ndarray_into_usm_ndarray ,
38
43
_searchsorted_left ,
39
44
_sort_ascending ,
40
45
)
41
- from ._type_utils import _resolve_weak_types_all_py_ints
46
+ from ._type_utils import (
47
+ _resolve_weak_types_all_py_ints ,
48
+ _to_device_supported_dtype ,
49
+ )
42
50
43
51
__all__ = [
44
52
"isin" ,
@@ -632,21 +640,17 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
632
640
)
633
641
634
642
635
- def isin (x , test_elements , / , * , assume_unique = False , invert = False ):
643
+ def isin (x , test_elements , / , * , invert = False ):
636
644
"""
637
645
Tests `x in test_elements` for each element of `x`. Returns a boolean array
638
646
with the same shape as `x` that is `True` where the element is in
639
647
`test_elements`, `False` otherwise.
640
648
641
649
Args:
642
- x (usm_ndarray):
643
- input array .
650
+ x (Union[ usm_ndarray, bool, int, float, complex] ):
651
+ input element or elements .
644
652
test_elements (Union[usm_ndarray, bool, int, float, complex]):
645
653
elements against which to test each value of `x`.
646
- assume_unique (Optional[bool]):
647
- if `True`, the input arrays are both assumed to be unique, which
648
- currently has no effect.
649
- Default: `False`.
650
654
invert (Optional[bool]):
651
655
if `True`, the output results are inverted, i.e., are equivalent to
652
656
testing `x not in test_elements` for each element of `x`.
@@ -657,11 +661,19 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
657
661
an array of the inclusion test results. The returned array has a
658
662
boolean data type and the same shape as `x`.
659
663
"""
660
- if not isinstance (x , dpt .usm_ndarray ):
661
- raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
662
- q1 , x_usm_type = x .sycl_queue , x .usm_type
664
+ q1 , x_usm_type = _get_queue_usm_type (x )
663
665
q2 , test_usm_type = _get_queue_usm_type (test_elements )
664
- if q2 is None :
666
+ if q1 is None and q2 is None :
667
+ raise du .ExecutionPlacementError (
668
+ "Execution placement can not be unambiguously inferred "
669
+ "from input arguments. "
670
+ "One of the arguments must represent USM allocation and "
671
+ "expose `__sycl_usm_array_interface__` property"
672
+ )
673
+ if q1 is None :
674
+ exec_q = q2
675
+ res_usm_type = test_usm_type
676
+ elif q2 is None :
665
677
exec_q = q1
666
678
res_usm_type = x_usm_type
667
679
else :
@@ -680,45 +692,60 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
680
692
dpctl .utils .validate_usm_type (res_usm_type , allow_none = False )
681
693
sycl_dev = exec_q .sycl_device
682
694
695
+ x_dt = _get_dtype (x , sycl_dev )
696
+ test_dt = _get_dtype (test_elements , sycl_dev )
697
+ if not all (_validate_dtype (dt ) for dt in (x_dt , test_dt )):
698
+ raise ValueError ("Operands have unsupported data types" )
699
+
700
+ x_sh = _get_shape (x )
683
701
if isinstance (test_elements , dpt .usm_ndarray ) and test_elements .size == 0 :
684
702
if invert :
685
- return dpt .ones_like (x , dtype = dpt .bool , usm_type = res_usm_type )
703
+ return dpt .ones (
704
+ x_sh , dtype = dpt .bool , usm_type = res_usm_type , sycl_queue = exec_q
705
+ )
686
706
else :
687
- return dpt .zeros_like (x , dtype = dpt .bool , usm_type = res_usm_type )
707
+ return dpt .zeros (
708
+ x_sh , dtype = dpt .bool , usm_type = res_usm_type , sycl_queue = exec_q
709
+ )
688
710
689
- x_dt = x .dtype
690
- test_dt = _get_dtype (test_elements , sycl_dev )
691
- if not _validate_dtype (test_dt ):
692
- raise ValueError ("`test_elements` has unsupported dtype" )
711
+ dt1 , dt2 = _resolve_weak_types_all_py_ints (x_dt , test_dt , sycl_dev )
712
+ dt = _to_device_supported_dtype (dpt .result_type (dt1 , dt2 ), sycl_dev )
713
+
714
+ if not isinstance (x , dpt .usm_ndarray ):
715
+ x_arr = dpt .asarray (
716
+ x , dtype = dt1 , usm_type = res_usm_type , sycl_queue = exec_q
717
+ )
718
+ else :
719
+ x_arr = x
720
+
721
+ if not isinstance (test_elements , dpt .usm_ndarray ):
722
+ test_arr = dpt .asarray (
723
+ test_elements , dtype = dt2 , usm_type = res_usm_type , sycl_queue = exec_q
724
+ )
725
+ else :
726
+ test_arr = test_elements
693
727
694
728
_manager = du .SequentialOrderManager [exec_q ]
695
729
dep_evs = _manager .submitted_events
696
730
697
- dt1 , dt2 = _resolve_weak_types_all_py_ints (x_dt , test_dt , sycl_dev )
698
- dt = dpt .result_type (dt1 , dt2 )
699
-
700
731
if x_dt != dt :
701
- x_buf = _empty_like_orderK (x , dt )
732
+ x_buf = _empty_like_orderK (x_arr , dt , res_usm_type , sycl_dev )
702
733
ht_ev , ev = _copy_usm_ndarray_into_usm_ndarray (
703
- src = x , dst = x_buf , sycl_queue = exec_q , depends = dep_evs
734
+ src = x_arr , dst = x_buf , sycl_queue = exec_q , depends = dep_evs
704
735
)
705
736
_manager .add_event_pair (ht_ev , ev )
706
737
else :
707
- x_buf = x
738
+ x_buf = x_arr
708
739
709
- if not isinstance (test_elements , dpt .usm_ndarray ):
710
- test_buf = dpt .asarray (
711
- test_elements , dtype = dt , usm_type = res_usm_type , sycl_queue = exec_q
712
- )
713
- elif test_dt != dt :
740
+ if test_dt != dt :
714
741
# copy into C-contiguous memory, because the array will be flattened
715
- test_buf = dpt .empty_like (test_elements , dtype = dt , order = "C" )
742
+ test_buf = dpt .empty_like (test_arr , dtype = dt , order = "C" )
716
743
ht_ev , ev = _copy_usm_ndarray_into_usm_ndarray (
717
- src = test_elements , dst = test_buf , sycl_queue = exec_q , depends = dep_evs
744
+ src = test_arr , dst = test_buf , sycl_queue = exec_q , depends = dep_evs
718
745
)
719
746
_manager .add_event_pair (ht_ev , ev )
720
747
else :
721
- test_buf = test_elements
748
+ test_buf = test_arr
722
749
723
750
test_buf = dpt .reshape (test_buf , - 1 )
724
751
test_buf = dpt .sort (test_buf )
0 commit comments