Skip to content

Commit 5b44997

Browse files
committed
Fix test_sort using wrong axis iteration
1 parent f108941 commit 5b44997

File tree

3 files changed

+19
-22
lines changed

3 files changed

+19
-22
lines changed

array_api_tests/meta/test_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,14 @@ def test_axis_ndindex(shape, axis, expected):
5656
@pytest.mark.parametrize(
5757
"shape, axes, expected",
5858
[
59-
((), (), [((),)]),
59+
((), (), [[()]]),
60+
((1,), (0,), [[(0,)]]),
6061
(
6162
(2, 2),
6263
(0,),
6364
[
64-
((0, 0), (1, 0)),
65-
((0, 1), (1, 1)),
65+
[(0, 0), (1, 0)],
66+
[(0, 1), (1, 1)],
6667
],
6768
),
6869
],

array_api_tests/test_sorting.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from hypothesis.control import assume
44

55
from . import _array_module as xp
6-
from . import array_helpers as ah
76
from . import dtype_helpers as dh
87
from . import hypothesis_helpers as hh
98
from . import pytest_helpers as ph
109
from . import xps
11-
from .test_manipulation_functions import assert_equals, axis_ndindex
10+
from .test_manipulation_functions import assert_equals
11+
from .test_statistical_functions import axes_ndindex, normalise_axis
1212

1313

1414
# TODO: generate kwargs
@@ -45,25 +45,21 @@ def test_sort(x, data):
4545
ph.assert_dtype("sort", out.dtype, x.dtype)
4646
ph.assert_shape("sort", out.shape, x.shape, **kw)
4747
axis = kw.get("axis", -1)
48-
_axis = axis if axis >= 0 else x.ndim + axis
48+
axes = normalise_axis(axis, x.ndim)
4949
descending = kw.get("descending", False)
5050
scalar_type = dh.get_scalar_type(x.dtype)
51-
for idx in axis_ndindex(x.shape, _axis):
52-
f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
53-
indexed_x = x[idx]
54-
indexed_out = out[idx]
55-
out_indices = list(ah.ndindex(indexed_x.shape))
56-
elements = [scalar_type(indexed_x[idx2]) for idx2 in out_indices]
51+
for indices in axes_ndindex(x.shape, axes):
52+
elements = [scalar_type(x[idx]) for idx in indices]
5753
indices_order = sorted(
58-
range(len(out_indices)), key=elements.__getitem__, reverse=descending
54+
range(len(indices)), key=elements.__getitem__, reverse=descending
5955
)
60-
x_indices = [out_indices[o] for o in indices_order]
61-
for out_idx, x_idx in zip(out_indices, x_indices):
56+
x_indices = [indices[o] for o in indices_order]
57+
for out_idx, x_idx in zip(indices, x_indices):
6258
assert_equals(
6359
"sort",
64-
f"x[{f_idx}][{x_idx}]",
65-
indexed_x[x_idx],
66-
f"out[{f_idx}][{out_idx}]",
67-
indexed_out[out_idx],
60+
f"x[{x_idx}]",
61+
x[x_idx],
62+
f"out[{out_idx}]",
63+
out[out_idx],
6864
**kw,
6965
)

array_api_tests/test_statistical_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import Iterator, Optional, Tuple, Union
3+
from typing import Iterator, List, Optional, Tuple, Union
44

55
from hypothesis import assume, given
66
from hypothesis import strategies as st
@@ -38,7 +38,7 @@ def normalise_axis(
3838
return axes
3939

4040

41-
def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, ...]]:
41+
def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
4242
"""Generate indices that index all elements except in `axes` dimensions"""
4343
base_indices = []
4444
axes_indices = []
@@ -58,7 +58,7 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, .
5858
idx[axis] = base_idx[axis]
5959
idx = tuple(idx)
6060
indices.append(idx)
61-
yield tuple(indices)
61+
yield list(indices)
6262

6363

6464
def assert_keepdimable_shape(

0 commit comments

Comments
 (0)