Skip to content

Commit db0a082

Browse files
committed
Allow x to be a scalar in isin and remove assume_unique
1 parent 03a0676 commit db0a082

File tree

1 file changed

+60
-33
lines changed

1 file changed

+60
-33
lines changed

dpctl/tensor/_set_functions.py

Lines changed: 60 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
import dpctl.utils as du
2222

2323
from ._copy_utils import _empty_like_orderK
24-
from ._scalar_utils import _get_dtype, _get_queue_usm_type, _validate_dtype
24+
from ._scalar_utils import (
25+
_get_dtype,
26+
_get_queue_usm_type,
27+
_get_shape,
28+
_validate_dtype,
29+
)
2530
from ._tensor_elementwise_impl import _not_equal, _subtract
2631
from ._tensor_impl import (
2732
_copy_usm_ndarray_into_usm_ndarray,
@@ -38,7 +43,10 @@
3843
_searchsorted_left,
3944
_sort_ascending,
4045
)
41-
from ._type_utils import _resolve_weak_types_all_py_ints
46+
from ._type_utils import (
47+
_resolve_weak_types_all_py_ints,
48+
_to_device_supported_dtype,
49+
)
4250

4351
__all__ = [
4452
"isin",
@@ -632,21 +640,17 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
632640
)
633641

634642

635-
def isin(x, test_elements, /, *, assume_unique=False, invert=False):
643+
def isin(x, test_elements, /, *, invert=False):
636644
"""
637645
Tests `x in test_elements` for each element of `x`. Returns a boolean array
638646
with the same shape as `x` that is `True` where the element is in
639647
`test_elements`, `False` otherwise.
640648
641649
Args:
642-
x (usm_ndarray):
643-
input array.
650+
x (Union[usm_ndarray, bool, int, float, complex]):
651+
input element or elements.
644652
test_elements (Union[usm_ndarray, bool, int, float, complex]):
645653
elements against which to test each value of `x`.
646-
assume_unique (Optional[bool]):
647-
if `True`, the input arrays are both assumed to be unique, which
648-
currently has no effect.
649-
Default: `False`.
650654
invert (Optional[bool]):
651655
if `True`, the output results are inverted, i.e., are equivalent to
652656
testing `x not in test_elements` for each element of `x`.
@@ -657,11 +661,19 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
657661
an array of the inclusion test results. The returned array has a
658662
boolean data type and the same shape as `x`.
659663
"""
660-
if not isinstance(x, dpt.usm_ndarray):
661-
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
662-
q1, x_usm_type = x.sycl_queue, x.usm_type
664+
q1, x_usm_type = _get_queue_usm_type(x)
663665
q2, test_usm_type = _get_queue_usm_type(test_elements)
664-
if q2 is None:
666+
if q1 is None and q2 is None:
667+
raise du.ExecutionPlacementError(
668+
"Execution placement can not be unambiguously inferred "
669+
"from input arguments. "
670+
"One of the arguments must represent USM allocation and "
671+
"expose `__sycl_usm_array_interface__` property"
672+
)
673+
if q1 is None:
674+
exec_q = q2
675+
res_usm_type = test_usm_type
676+
elif q2 is None:
665677
exec_q = q1
666678
res_usm_type = x_usm_type
667679
else:
@@ -680,45 +692,60 @@ def isin(x, test_elements, /, *, assume_unique=False, invert=False):
680692
dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
681693
sycl_dev = exec_q.sycl_device
682694

695+
x_dt = _get_dtype(x, sycl_dev)
696+
test_dt = _get_dtype(test_elements, sycl_dev)
697+
if not all(_validate_dtype(dt) for dt in (x_dt, test_dt)):
698+
raise ValueError("Operands have unsupported data types")
699+
700+
x_sh = _get_shape(x)
683701
if isinstance(test_elements, dpt.usm_ndarray) and test_elements.size == 0:
684702
if invert:
685-
return dpt.ones_like(x, dtype=dpt.bool, usm_type=res_usm_type)
703+
return dpt.ones(
704+
x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q
705+
)
686706
else:
687-
return dpt.zeros_like(x, dtype=dpt.bool, usm_type=res_usm_type)
707+
return dpt.zeros(
708+
x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q
709+
)
688710

689-
x_dt = x.dtype
690-
test_dt = _get_dtype(test_elements, sycl_dev)
691-
if not _validate_dtype(test_dt):
692-
raise ValueError("`test_elements` has unsupported dtype")
711+
dt1, dt2 = _resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev)
712+
dt = _to_device_supported_dtype(dpt.result_type(dt1, dt2), sycl_dev)
713+
714+
if not isinstance(x, dpt.usm_ndarray):
715+
x_arr = dpt.asarray(
716+
x, dtype=dt1, usm_type=res_usm_type, sycl_queue=exec_q
717+
)
718+
else:
719+
x_arr = x
720+
721+
if not isinstance(test_elements, dpt.usm_ndarray):
722+
test_arr = dpt.asarray(
723+
test_elements, dtype=dt2, usm_type=res_usm_type, sycl_queue=exec_q
724+
)
725+
else:
726+
test_arr = test_elements
693727

694728
_manager = du.SequentialOrderManager[exec_q]
695729
dep_evs = _manager.submitted_events
696730

697-
dt1, dt2 = _resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev)
698-
dt = dpt.result_type(dt1, dt2)
699-
700731
if x_dt != dt:
701-
x_buf = _empty_like_orderK(x, dt)
732+
x_buf = _empty_like_orderK(x_arr, dt, res_usm_type, sycl_dev)
702733
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
703-
src=x, dst=x_buf, sycl_queue=exec_q, depends=dep_evs
734+
src=x_arr, dst=x_buf, sycl_queue=exec_q, depends=dep_evs
704735
)
705736
_manager.add_event_pair(ht_ev, ev)
706737
else:
707-
x_buf = x
738+
x_buf = x_arr
708739

709-
if not isinstance(test_elements, dpt.usm_ndarray):
710-
test_buf = dpt.asarray(
711-
test_elements, dtype=dt, usm_type=res_usm_type, sycl_queue=exec_q
712-
)
713-
elif test_dt != dt:
740+
if test_dt != dt:
714741
# copy into C-contiguous memory, because the array will be flattened
715-
test_buf = dpt.empty_like(test_elements, dtype=dt, order="C")
742+
test_buf = dpt.empty_like(test_arr, dtype=dt, order="C")
716743
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
717-
src=test_elements, dst=test_buf, sycl_queue=exec_q, depends=dep_evs
744+
src=test_arr, dst=test_buf, sycl_queue=exec_q, depends=dep_evs
718745
)
719746
_manager.add_event_pair(ht_ev, ev)
720747
else:
721-
test_buf = test_elements
748+
test_buf = test_arr
722749

723750
test_buf = dpt.reshape(test_buf, -1)
724751
test_buf = dpt.sort(test_buf)

0 commit comments

Comments
 (0)