Skip to content

Commit 172bb57

Browse files
committed
Add argsort and sort wrappers
1 parent 3de4460 commit 172bb57

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

numpy_array_api_compat/_aliases.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,39 @@ def reshape(x: ndarray, /, shape: Tuple[int, ...], copy: Optional[bool] = None)
277277
return x
278278
return np.reshape(x, shape)
279279

280+
# The descending keyword is new in sort and argsort, and 'kind' replaced with
281+
# 'stable'
282+
def argsort(
283+
x: ndarray, /, *, axis: int = -1, descending: bool = False, stable: bool = True
284+
) -> ndarray:
285+
# Note: this keyword argument is different, and the default is different.
286+
kind = "stable" if stable else "quicksort"
287+
if not descending:
288+
res = np.argsort(x, axis=axis, kind=kind)
289+
else:
290+
# As NumPy has no native descending sort, we imitate it here. Note that
291+
# simply flipping the results of np.argsort(x, ...) would not
292+
# respect the relative order like it would in native descending sorts.
293+
res = np.flip(
294+
np.argsort(np.flip(x, axis=axis), axis=axis, kind=kind),
295+
axis=axis,
296+
)
297+
# Rely on flip()/argsort() to validate axis
298+
normalised_axis = axis if axis >= 0 else x.ndim + axis
299+
max_i = x.shape[normalised_axis] - 1
300+
res = max_i - res
301+
return res
302+
303+
def sort(
304+
x: ndarray, /, *, axis: int = -1, descending: bool = False, stable: bool = True
305+
) -> ndarray:
306+
# Note: this keyword argument is different, and the default is different.
307+
kind = "stable" if stable else "quicksort"
308+
res = np.sort(x, axis=axis, kind=kind)
309+
if descending:
310+
res = np.flip(res, axis=axis)
311+
return res
312+
280313
# from numpy import * doesn't overwrite these builtin names
281314
from numpy import abs, max, min, round
282315

@@ -287,4 +320,5 @@ def reshape(x: ndarray, /, shape: Tuple[int, ...], copy: Optional[bool] = None)
287320
'unique_inverse', 'unique_values', 'astype', 'abs', 'max', 'min',
288321
'round', 'std', 'var', 'permute_dims', 'asarray', 'arange',
289322
'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace',
290-
'ones', 'ones_like', 'zeros', 'zeros_like', 'reshape']
323+
'ones', 'ones_like', 'zeros', 'zeros_like', 'reshape', 'argsort',
324+
'sort']

0 commit comments

Comments
 (0)