Skip to content

Commit cc6cc4d

Browse files
committed
Update dtype strictness in array_api searching and sorting functions
Original NumPy Commit: 837b1af70ecea4877c8b1fee327d73d6dace517a
1 parent 939738e commit cc6cc4d

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

array_api_strict/_searching_functions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from ._array_object import Array
4-
from ._dtypes import _result_type
4+
from ._dtypes import _result_type, _real_numeric_dtypes
55

66
from typing import Optional, Tuple
77

@@ -14,6 +14,8 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
1414
1515
See its docstring for more information.
1616
"""
17+
if x.dtype not in _real_numeric_dtypes:
18+
raise TypeError("Only real numeric dtypes are allowed in argmax")
1719
return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)))
1820

1921

@@ -23,6 +25,8 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
2325
2426
See its docstring for more information.
2527
"""
28+
if x.dtype not in _real_numeric_dtypes:
29+
raise TypeError("Only real numeric dtypes are allowed in argmin")
2630
return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims)))
2731

2832

array_api_strict/_sorting_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from ._array_object import Array
4+
from ._dtypes import _real_numeric_dtypes
45

56
import numpy as np
67

@@ -14,6 +15,8 @@ def argsort(
1415
1516
See its docstring for more information.
1617
"""
18+
if x.dtype not in _real_numeric_dtypes:
19+
raise TypeError("Only real numeric dtypes are allowed in argsort")
1720
# Note: this keyword argument is different, and the default is different.
1821
kind = "stable" if stable else "quicksort"
1922
if not descending:
@@ -41,6 +44,8 @@ def sort(
4144
4245
See its docstring for more information.
4346
"""
47+
if x.dtype not in _real_numeric_dtypes:
48+
raise TypeError("Only real numeric dtypes are allowed in sort")
4449
# Note: this keyword argument is different, and the default is different.
4550
kind = "stable" if stable else "quicksort"
4651
res = np.sort(x._array, axis=axis, kind=kind)

0 commit comments

Comments
 (0)