|
16 | 16 |
|
17 | 17 | from typing import NamedTuple
|
18 | 18 |
|
| 19 | +import dpctl |
19 | 20 | import dpctl.tensor as dpt
|
20 | 21 | import dpctl.utils as du
|
21 | 22 |
|
22 | 23 | from ._copy_utils import _empty_like_orderK
|
| 24 | +from ._scalar_utils import _get_dtype, _get_queue_usm_type, _validate_dtype |
23 | 25 | from ._tensor_elementwise_impl import _not_equal, _subtract
|
24 | 26 | from ._tensor_impl import (
|
25 | 27 | _copy_usm_ndarray_into_usm_ndarray,
|
|
36 | 38 | _searchsorted_left,
|
37 | 39 | _sort_ascending,
|
38 | 40 | )
|
| 41 | +from ._type_utils import _resolve_weak_types_all_py_ints |
39 | 42 |
|
40 | 43 | __all__ = [
|
| 44 | + "isin", |
41 | 45 | "unique_values",
|
42 | 46 | "unique_counts",
|
43 | 47 | "unique_inverse",
|
@@ -629,61 +633,101 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
|
629 | 633 |
|
630 | 634 |
|
631 | 635 | def isin(x, test_elements, /, *, assume_unique=False, invert=False):
|
| 636 | + """ |
| 637 | + Tests `x in test_elements` for each element of `x`. Returns a boolean array |
| 638 | + with the same shape as `x` that is `True` where the element is in |
| 639 | + `test_elements`, `False` otherwise. |
| 640 | +
|
| 641 | + Args: |
| 642 | + x (usm_ndarray): |
| 643 | + input array. |
| 644 | + test_elements (Union[usm_ndarray, bool, int, float, complex]): |
| 645 | + elements against which to test each value of `x`. |
| 646 | + Default: `None`. |
| 647 | + assume_unique (Optional[bool]): |
| 648 | + if `True`, the input arrays are both assumed to be unique, which |
| 649 | + currently has no effect. |
| 650 | + Default: `False`. |
| 651 | + invert (Optional[bool]): |
| 652 | + if `True`, the output results are inverted, i.e., are equivalent to |
| 653 | + testing `x not in test_elements` for each element of `x`. |
| 654 | + Default: `False`. |
| 655 | +
|
| 656 | + Returns: |
| 657 | + usm_ndarray: |
| 658 | + an array of the inclusion test results. The returned array has a |
| 659 | + boolean data type and the same shape as `x`. |
| 660 | + """ |
632 | 661 | if not isinstance(x, dpt.usm_ndarray):
|
633 | 662 | raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
|
634 |
| - if not isinstance(test_elements, dpt.usm_ndarray): |
635 |
| - raise TypeError( |
636 |
| - f"Expected dpctl.tensor.usm_ndarray, got {type(test_elements)}" |
| 663 | + q1, x_usm_type = x.sycl_queue, x.usm_type |
| 664 | + q2, test_usm_type = _get_queue_usm_type(test_elements) |
| 665 | + if q2 is None: |
| 666 | + exec_q = q1 |
| 667 | + res_usm_type = x_usm_type |
| 668 | + else: |
| 669 | + exec_q = dpctl.utils.get_execution_queue((q1, q2)) |
| 670 | + if exec_q is None: |
| 671 | + raise du.ExecutionPlacementError( |
| 672 | + "Execution placement can not be unambiguously inferred " |
| 673 | + "from input arguments." |
| 674 | + ) |
| 675 | + res_usm_type = dpctl.utils.get_coerced_usm_type( |
| 676 | + ( |
| 677 | + x_usm_type, |
| 678 | + test_usm_type, |
| 679 | + ) |
637 | 680 | )
|
| 681 | + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) |
| 682 | + sycl_dev = exec_q.sycl_device |
638 | 683 |
|
639 |
| - q = du.get_execution_queue([x.sycl_queue, test_elements.sycl_queue]) |
640 |
| - if q is None: |
641 |
| - raise du.ExecutionPlacementError( |
642 |
| - "Execution placement can not be unambiguously " |
643 |
| - "inferred from input arguments." |
644 |
| - ) |
| 684 | + x_dt = x.dtype |
| 685 | + test_dt = _get_dtype(test_elements, sycl_dev) |
| 686 | + if not _validate_dtype(test_dt): |
| 687 | + raise ValueError("`test_elements` has unsupported dtype") |
645 | 688 |
|
646 |
| - x1 = x |
647 |
| - x2 = dpt.reshape(test_elements, -1) |
| 689 | + dt = dpt.result_type( |
| 690 | + *_resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev) |
| 691 | + ) |
648 | 692 |
|
649 |
| - x1_dt = x1.dtype |
650 |
| - x2_dt = x2.dtype |
| 693 | + _manager = du.SequentialOrderManager[exec_q] |
651 | 694 |
|
652 |
| - _manager = du.SequentialOrderManager[q] |
653 |
| - dep_evs = _manager.submitted_events |
| 695 | + if x_dt != dt: |
| 696 | + x_buf = _empty_like_orderK(x, dt) |
| 697 | + dep_evs = _manager.submitted_events |
| 698 | + ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( |
| 699 | + src=x, dst=x_buf, sycl_queue=exec_q, depends=dep_evs |
| 700 | + ) |
| 701 | + _manager.add_event_pair(ht_ev, ev) |
| 702 | + else: |
| 703 | + x_buf = x |
654 | 704 |
|
655 |
| - if x1_dt != x2_dt: |
656 |
| - dt = dpt.result_type(x1, x2) |
657 |
| - if x1_dt != dt: |
658 |
| - x1_buf = _empty_like_orderK(x1, dt) |
659 |
| - dep_evs = _manager.submitted_events |
660 |
| - ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( |
661 |
| - src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs |
662 |
| - ) |
663 |
| - _manager.add_event_pair(ht_ev, ev) |
664 |
| - x1 = x1_buf |
665 |
| - if x2_dt != dt: |
666 |
| - x2_buf = _empty_like_orderK(x2, dt) |
667 |
| - dep_evs = _manager.submitted_events |
668 |
| - ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( |
669 |
| - src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs |
670 |
| - ) |
671 |
| - _manager.add_event_pair(ht_ev, ev) |
672 |
| - x2 = x2_buf |
| 705 | + if not isinstance(test_elements, dpt.usm_ndarray): |
| 706 | + test_buf = dpt.asarray(test_elements, dtype=dt, sycl_queue=exec_q) |
| 707 | + elif test_dt != dt: |
| 708 | + # copy into C-contiguous memory, because the array will be flattened |
| 709 | + test_buf = dpt.empty_like(test_elements, dt, order="C") |
| 710 | + dep_evs = _manager.submitted_events |
| 711 | + ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( |
| 712 | + src=test_elements, dst=test_buf, sycl_queue=exec_q, depends=dep_evs |
| 713 | + ) |
| 714 | + _manager.add_event_pair(ht_ev, ev) |
| 715 | + else: |
| 716 | + test_buf = test_elements |
673 | 717 |
|
674 |
| - x2 = dpt.sort(x2) |
| 718 | + test_buf = dpt.reshape(test_buf, -1) |
| 719 | + test_buf = dpt.sort(test_buf) |
675 | 720 |
|
676 |
| - dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type]) |
677 |
| - dst = _empty_like_orderK(x1, dpt.bool, usm_type=dst_usm_type) |
| 721 | + dst = _empty_like_orderK(x_buf, dpt.bool, usm_type=res_usm_type) |
678 | 722 |
|
679 | 723 | dep_evs = _manager.submitted_events
|
680 | 724 | ht_ev, s_ev = _isin(
|
681 |
| - needles=x1, |
682 |
| - hay=x2, |
| 725 | + needles=x_buf, |
| 726 | + hay=test_buf, |
683 | 727 | dst=dst,
|
684 |
| - sycl_queue=q, |
| 728 | + sycl_queue=exec_q, |
685 | 729 | invert=invert,
|
686 | 730 | depends=dep_evs,
|
687 | 731 | )
|
688 | 732 | _manager.add_event_pair(ht_ev, s_ev)
|
689 |
| - return dpt.reshape(dst, x.shape) |
| 733 | + return dst |
0 commit comments