Skip to content

Commit 1c4460d

Browse files
committed
Add searchsorted
As far as I can tell, except for the dtype restriction, the standard is the same as NumPy.
1 parent 095be2f commit 1c4460d

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,9 @@
286286

287287
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack"]
288288

289-
from ._searching_functions import argmax, argmin, nonzero, where
289+
from ._searching_functions import argmax, argmin, nonzero, searchsorted, where
290290

291-
__all__ += ["argmax", "argmin", "nonzero", "where"]
291+
__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"]
292292

293293
from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
294294

array_api_strict/_searching_functions.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from ._array_object import Array
44
from ._dtypes import _result_type, _real_numeric_dtypes
5-
from ._flags import requires_data_dependent_shapes
5+
from ._flags import requires_data_dependent_shapes, requires_api_version
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
9-
from typing import Optional, Tuple
9+
from typing import Literal, Optional, Tuple
1010

1111
import numpy as np
1212

@@ -45,6 +45,26 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]:
4545
raise ValueError("nonzero is not allowed on 0-dimensional arrays")
4646
return tuple(Array._new(i) for i in np.nonzero(x._array))
4747

48+
@requires_api_version('2023.12')
49+
def searchsorted(
50+
x1: Array,
51+
x2: Array,
52+
/,
53+
*,
54+
side: Literal["left", "right"] = "left",
55+
sorter: Optional[Array] = None,
56+
) -> Array:
57+
"""
58+
Array API compatible wrapper for :py:func:`np.searchsorted <numpy.searchsorted>`.
59+
60+
See its docstring for more information.
61+
"""
62+
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
63+
raise TypeError("Only real numeric dtypes are allowed in searchsorted")
64+
sorter = sorter._array if sorter is not None else None
65+
# TODO: The sort order of nans and signed zeros is implementation
66+
# dependent. Should we error/warn if they are present?
67+
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter))
4868

4969
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
5070
"""

0 commit comments

Comments
 (0)