Skip to content

Commit 7e039c1

Browse files
committed
Fixes unstable argsort() test logic
1 parent 3e6504c commit 7e039c1

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

array_api_tests/test_sorting_functions.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,17 @@
55
from hypothesis import strategies as st
66
from hypothesis.control import assume
77

8-
from .typing import Scalar, ScalarType, Shape
9-
108
from . import _array_module as xp
119
from . import dtype_helpers as dh
1210
from . import hypothesis_helpers as hh
1311
from . import pytest_helpers as ph
1412
from . import shape_helpers as sh
1513
from . import xps
14+
from .typing import Scalar, Shape
1615

1716

1817
def assert_scalar_in_set(
1918
func_name: str,
20-
type_: ScalarType,
2119
idx: Shape,
2220
out: Scalar,
2321
set_: Set[Scalar],
@@ -62,13 +60,12 @@ def test_argsort(x, data):
6260
scalar_type = dh.get_scalar_type(x.dtype)
6361
for indices in sh.axes_ndindex(x.shape, axes):
6462
elements = [scalar_type(x[idx]) for idx in indices]
65-
orders = sorted(
66-
range(len(elements)),
67-
key=elements.__getitem__,
68-
reverse=kw.get("descending", False),
63+
orders = list(range(len(elements)))
64+
sorders = sorted(
65+
orders, key=elements.__getitem__, reverse=kw.get("descending", False)
6966
)
7067
if kw.get("stable", True):
71-
for idx, o in zip(indices, orders):
68+
for idx, o in zip(indices, sorders):
7269
ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o, **kw)
7370
else:
7471
idx_elements = dict(zip(indices, elements))
@@ -78,17 +75,17 @@ def test_argsort(x, data):
7875
element_orders[e] = [
7976
idx_orders[idx] for idx in indices if idx_elements[idx] == e
8077
]
81-
for idx, e in zip(indices, elements):
82-
o = int(out[idx])
78+
selements = [elements[o] for o in sorders]
79+
for idx, e in zip(indices, selements):
8380
expected_orders = element_orders[e]
81+
out_o = int(out[idx])
8482
if len(expected_orders) == 1:
85-
expected_order = expected_orders[0]
8683
ph.assert_scalar_equals(
87-
"argsort", int, idx, o, expected_order, **kw
84+
"argsort", int, idx, out_o, expected_orders[0], **kw
8885
)
8986
else:
9087
assert_scalar_in_set(
91-
"argsort", int, idx, o, set(expected_orders), **kw
88+
"argsort", idx, out_o, set(expected_orders), **kw
9289
)
9390

9491

@@ -129,6 +126,7 @@ def test_sort(x, data):
129126
)
130127
for out_idx, o in zip(indices, orders):
131128
x_idx = indices[o]
129+
# TODO: error message when unstable should not imply just one idx
132130
ph.assert_0d_equals(
133131
"sort",
134132
f"x[{x_idx}]",

0 commit comments

Comments
 (0)