Skip to content

Commit 36ee0f2

Browse files
committed
Update implementation of isin
permit scalar input for second argument, address some review comments, add docstring
1 parent 4d12620 commit 36ee0f2

File tree

1 file changed

+84
-40
lines changed

1 file changed

+84
-40
lines changed

dpctl/tensor/_set_functions.py

Lines changed: 84 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
from typing import NamedTuple
1818

19+
import dpctl
1920
import dpctl.tensor as dpt
2021
import dpctl.utils as du
2122

2223
from ._copy_utils import _empty_like_orderK
24+
from ._scalar_utils import _get_dtype, _get_queue_usm_type, _validate_dtype
2325
from ._tensor_elementwise_impl import _not_equal, _subtract
2426
from ._tensor_impl import (
2527
_copy_usm_ndarray_into_usm_ndarray,
@@ -36,8 +38,10 @@
3638
_searchsorted_left,
3739
_sort_ascending,
3840
)
41+
from ._type_utils import _resolve_weak_types_all_py_ints
3942

4043
__all__ = [
44+
"isin",
4145
"unique_values",
4246
"unique_counts",
4347
"unique_inverse",
@@ -629,61 +633,101 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
629633

630634

631635
def isin(x, test_elements, /, *, assume_unique=False, invert=False):
636+
"""
637+
Tests `x in test_elements` for each element of `x`. Returns a boolean array
638+
with the same shape as `x` that is `True` where the element is in
639+
`test_elements`, `False` otherwise.
640+
641+
Args:
642+
x (usm_ndarray):
643+
input array.
644+
test_elements (Union[usm_ndarray, bool, int, float, complex]):
645+
elements against which to test each value of `x`.
646+
Default: `None`.
647+
assume_unique (Optional[bool]):
648+
if `True`, the input arrays are both assumed to be unique, which
649+
currently has no effect.
650+
Default: `False`.
651+
invert (Optional[bool]):
652+
if `True`, the output results are inverted, i.e., are equivalent to
653+
testing `x not in test_elements` for each element of `x`.
654+
Default: `False`.
655+
656+
Returns:
657+
usm_ndarray:
658+
an array of the inclusion test results. The returned array has a
659+
boolean data type and the same shape as `x`.
660+
"""
632661
if not isinstance(x, dpt.usm_ndarray):
633662
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
634-
if not isinstance(test_elements, dpt.usm_ndarray):
635-
raise TypeError(
636-
f"Expected dpctl.tensor.usm_ndarray, got {type(test_elements)}"
663+
q1, x_usm_type = x.sycl_queue, x.usm_type
664+
q2, test_usm_type = _get_queue_usm_type(test_elements)
665+
if q2 is None:
666+
exec_q = q1
667+
res_usm_type = x_usm_type
668+
else:
669+
exec_q = dpctl.utils.get_execution_queue((q1, q2))
670+
if exec_q is None:
671+
raise du.ExecutionPlacementError(
672+
"Execution placement can not be unambiguously inferred "
673+
"from input arguments."
674+
)
675+
res_usm_type = dpctl.utils.get_coerced_usm_type(
676+
(
677+
x_usm_type,
678+
test_usm_type,
679+
)
637680
)
681+
dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
682+
sycl_dev = exec_q.sycl_device
638683

639-
q = du.get_execution_queue([x.sycl_queue, test_elements.sycl_queue])
640-
if q is None:
641-
raise du.ExecutionPlacementError(
642-
"Execution placement can not be unambiguously "
643-
"inferred from input arguments."
644-
)
684+
x_dt = x.dtype
685+
test_dt = _get_dtype(test_elements, sycl_dev)
686+
if not _validate_dtype(test_dt):
687+
raise ValueError("`test_elements` has unsupported dtype")
645688

646-
x1 = x
647-
x2 = dpt.reshape(test_elements, -1)
689+
dt = dpt.result_type(
690+
*_resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev)
691+
)
648692

649-
x1_dt = x1.dtype
650-
x2_dt = x2.dtype
693+
_manager = du.SequentialOrderManager[exec_q]
651694

652-
_manager = du.SequentialOrderManager[q]
653-
dep_evs = _manager.submitted_events
695+
if x_dt != dt:
696+
x_buf = _empty_like_orderK(x, dt)
697+
dep_evs = _manager.submitted_events
698+
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
699+
src=x, dst=x_buf, sycl_queue=exec_q, depends=dep_evs
700+
)
701+
_manager.add_event_pair(ht_ev, ev)
702+
else:
703+
x_buf = x
654704

655-
if x1_dt != x2_dt:
656-
dt = dpt.result_type(x1, x2)
657-
if x1_dt != dt:
658-
x1_buf = _empty_like_orderK(x1, dt)
659-
dep_evs = _manager.submitted_events
660-
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
661-
src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
662-
)
663-
_manager.add_event_pair(ht_ev, ev)
664-
x1 = x1_buf
665-
if x2_dt != dt:
666-
x2_buf = _empty_like_orderK(x2, dt)
667-
dep_evs = _manager.submitted_events
668-
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
669-
src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
670-
)
671-
_manager.add_event_pair(ht_ev, ev)
672-
x2 = x2_buf
705+
if not isinstance(test_elements, dpt.usm_ndarray):
706+
test_buf = dpt.asarray(test_elements, dtype=dt, sycl_queue=exec_q)
707+
elif test_dt != dt:
708+
# copy into C-contiguous memory, because the array will be flattened
709+
test_buf = dpt.empty_like(test_elements, dt, order="C")
710+
dep_evs = _manager.submitted_events
711+
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
712+
src=test_elements, dst=test_buf, sycl_queue=exec_q, depends=dep_evs
713+
)
714+
_manager.add_event_pair(ht_ev, ev)
715+
else:
716+
test_buf = test_elements
673717

674-
x2 = dpt.sort(x2)
718+
test_buf = dpt.reshape(test_buf, -1)
719+
test_buf = dpt.sort(test_buf)
675720

676-
dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
677-
dst = _empty_like_orderK(x1, dpt.bool, usm_type=dst_usm_type)
721+
dst = _empty_like_orderK(x_buf, dpt.bool, usm_type=res_usm_type)
678722

679723
dep_evs = _manager.submitted_events
680724
ht_ev, s_ev = _isin(
681-
needles=x1,
682-
hay=x2,
725+
needles=x_buf,
726+
hay=test_buf,
683727
dst=dst,
684-
sycl_queue=q,
728+
sycl_queue=exec_q,
685729
invert=invert,
686730
depends=dep_evs,
687731
)
688732
_manager.add_event_pair(ht_ev, s_ev)
689-
return dpt.reshape(dst, x.shape)
733+
return dst

0 commit comments

Comments
 (0)