|
7 | 7 | from . import hypothesis_helpers as hh
|
8 | 8 | from . import pytest_helpers as ph
|
9 | 9 | from . import xps
|
10 |
| -from .test_manipulation_functions import assert_equals |
11 |
| -from .test_statistical_functions import axes_ndindex, normalise_axis |
| 10 | +from .test_manipulation_functions import assert_equals as assert_equals_ |
| 11 | +from .test_searching_functions import assert_default_index |
| 12 | +from .test_statistical_functions import assert_equals, axes_ndindex, normalise_axis |
12 | 13 |
|
13 | 14 |
|
14 |
| -# TODO: generate kwargs |
15 |
| -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) |
16 |
| -def test_argsort(x): |
17 |
| - xp.argsort(x) |
18 |
| - # TODO |
| 15 | +# TODO: Test with signed zeros and NaNs (and ignore them somehow) |
| 16 | +@given( |
| 17 | + x=xps.arrays( |
| 18 | + dtype=xps.scalar_dtypes(), |
| 19 | + shape=hh.shapes(min_dims=1, min_side=1), |
| 20 | + elements={"allow_nan": False}, |
| 21 | + ), |
| 22 | + data=st.data(), |
| 23 | +) |
| 24 | +def test_argsort(x, data): |
| 25 | + if dh.is_float_dtype(x.dtype): |
| 26 | + assume(not xp.any(x == -0.0) and not xp.any(x == +0.0)) |
| 27 | + |
| 28 | + kw = data.draw( |
| 29 | + hh.kwargs( |
| 30 | + axis=st.integers(-x.ndim, x.ndim - 1), |
| 31 | + descending=st.booleans(), |
| 32 | + stable=st.booleans(), |
| 33 | + ), |
| 34 | + label="kw", |
| 35 | + ) |
| 36 | + |
| 37 | + out = xp.argsort(x, **kw) |
| 38 | + |
| 39 | + assert_default_index("sort", out.dtype) |
| 40 | + ph.assert_shape("sort", out.shape, x.shape, **kw) |
| 41 | + axis = kw.get("axis", -1) |
| 42 | + axes = normalise_axis(axis, x.ndim) |
| 43 | + descending = kw.get("descending", False) |
| 44 | + scalar_type = dh.get_scalar_type(x.dtype) |
| 45 | + for indices in axes_ndindex(x.shape, axes): |
| 46 | + elements = [scalar_type(x[idx]) for idx in indices] |
| 47 | + indices_order = sorted(range(len(indices)), key=elements.__getitem__) |
| 48 | + if descending: |
| 49 | + # sorted(..., reverse=descending) doesn't always work |
| 50 | + indices_order = reversed(indices_order) |
| 51 | + for idx, o in zip(indices, indices_order): |
| 52 | + assert_equals("argsort", int, idx, int(out[idx]), o) |
19 | 53 |
|
20 | 54 |
|
21 | 55 | # TODO: Test with signed zeros and NaNs (and ignore them somehow)
|
@@ -55,7 +89,7 @@ def test_sort(x, data):
|
55 | 89 | )
|
56 | 90 | x_indices = [indices[o] for o in indices_order]
|
57 | 91 | for out_idx, x_idx in zip(indices, x_indices):
|
58 |
| - assert_equals( |
| 92 | + assert_equals_( |
59 | 93 | "sort",
|
60 | 94 | f"x[{x_idx}]",
|
61 | 95 | x[x_idx],
|
|
0 commit comments