Skip to content

Commit c0a47fd

Browse files
committed
Cover most things for test_argsort
1 parent 5b44997 commit c0a47fd

File tree

1 file changed

+42
-8
lines changed

1 file changed

+42
-8
lines changed

array_api_tests/test_sorting.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,49 @@
77
from . import hypothesis_helpers as hh
88
from . import pytest_helpers as ph
99
from . import xps
10-
from .test_manipulation_functions import assert_equals
11-
from .test_statistical_functions import axes_ndindex, normalise_axis
10+
from .test_manipulation_functions import assert_equals as assert_equals_
11+
from .test_searching_functions import assert_default_index
12+
from .test_statistical_functions import assert_equals, axes_ndindex, normalise_axis
1213

1314

14-
# TODO: generate kwargs
15-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
16-
def test_argsort(x):
17-
xp.argsort(x)
18-
# TODO
15+
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
16+
@given(
17+
x=xps.arrays(
18+
dtype=xps.scalar_dtypes(),
19+
shape=hh.shapes(min_dims=1, min_side=1),
20+
elements={"allow_nan": False},
21+
),
22+
data=st.data(),
23+
)
24+
def test_argsort(x, data):
25+
if dh.is_float_dtype(x.dtype):
26+
assume(not xp.any(x == -0.0) and not xp.any(x == +0.0))
27+
28+
kw = data.draw(
29+
hh.kwargs(
30+
axis=st.integers(-x.ndim, x.ndim - 1),
31+
descending=st.booleans(),
32+
stable=st.booleans(),
33+
),
34+
label="kw",
35+
)
36+
37+
out = xp.argsort(x, **kw)
38+
39+
assert_default_index("sort", out.dtype)
40+
ph.assert_shape("sort", out.shape, x.shape, **kw)
41+
axis = kw.get("axis", -1)
42+
axes = normalise_axis(axis, x.ndim)
43+
descending = kw.get("descending", False)
44+
scalar_type = dh.get_scalar_type(x.dtype)
45+
for indices in axes_ndindex(x.shape, axes):
46+
elements = [scalar_type(x[idx]) for idx in indices]
47+
indices_order = sorted(range(len(indices)), key=elements.__getitem__)
48+
if descending:
49+
# sorted(..., reverse=descending) doesn't always work
50+
indices_order = reversed(indices_order)
51+
for idx, o in zip(indices, indices_order):
52+
assert_equals("argsort", int, idx, int(out[idx]), o)
1953

2054

2155
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@@ -55,7 +89,7 @@ def test_sort(x, data):
5589
)
5690
x_indices = [indices[o] for o in indices_order]
5791
for out_idx, x_idx in zip(indices, x_indices):
58-
assert_equals(
92+
assert_equals_(
5993
"sort",
6094
f"x[{x_idx}]",
6195
x[x_idx],

0 commit comments

Comments
 (0)