Skip to content

Commit 314e1be

Browse files
authored
Merge pull request #68 from honno/argsort-desc
Fix `test_argsort` for unstable and descending scenarios
2 parents cad86ef + 7e039c1 commit 314e1be

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

array_api_tests/test_sorting_functions.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,17 @@
55
from hypothesis import strategies as st
66
from hypothesis.control import assume
77

8-
from .typing import Scalar, ScalarType, Shape
9-
108
from . import _array_module as xp
119
from . import dtype_helpers as dh
1210
from . import hypothesis_helpers as hh
1311
from . import pytest_helpers as ph
1412
from . import shape_helpers as sh
1513
from . import xps
14+
from .typing import Scalar, Shape
1615

1716

1817
def assert_scalar_in_set(
1918
func_name: str,
20-
type_: ScalarType,
2119
idx: Shape,
2220
out: Scalar,
2321
set_: Set[Scalar],
@@ -62,12 +60,13 @@ def test_argsort(x, data):
6260
scalar_type = dh.get_scalar_type(x.dtype)
6361
for indices in sh.axes_ndindex(x.shape, axes):
6462
elements = [scalar_type(x[idx]) for idx in indices]
65-
orders = sorted(range(len(elements)), key=elements.__getitem__)
66-
if kw.get("descending", False):
67-
orders = reversed(orders)
63+
orders = list(range(len(elements)))
64+
sorders = sorted(
65+
orders, key=elements.__getitem__, reverse=kw.get("descending", False)
66+
)
6867
if kw.get("stable", True):
69-
for idx, o in zip(indices, orders):
70-
ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o)
68+
for idx, o in zip(indices, sorders):
69+
ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o, **kw)
7170
else:
7271
idx_elements = dict(zip(indices, elements))
7372
idx_orders = dict(zip(indices, orders))
@@ -76,17 +75,17 @@ def test_argsort(x, data):
7675
element_orders[e] = [
7776
idx_orders[idx] for idx in indices if idx_elements[idx] == e
7877
]
79-
for idx, e in zip(indices, elements):
80-
o = int(out[idx])
78+
selements = [elements[o] for o in sorders]
79+
for idx, e in zip(indices, selements):
8180
expected_orders = element_orders[e]
81+
out_o = int(out[idx])
8282
if len(expected_orders) == 1:
83-
expected_order = expected_orders[0]
8483
ph.assert_scalar_equals(
85-
"argsort", int, idx, o, expected_order, **kw
84+
"argsort", int, idx, out_o, expected_orders[0], **kw
8685
)
8786
else:
8887
assert_scalar_in_set(
89-
"argsort", int, idx, o, set(expected_orders), **kw
88+
"argsort", idx, out_o, set(expected_orders), **kw
9089
)
9190

9291

@@ -127,6 +126,7 @@ def test_sort(x, data):
127126
)
128127
for out_idx, o in zip(indices, orders):
129128
x_idx = indices[o]
129+
# TODO: error message when unstable should not imply just one idx
130130
ph.assert_0d_equals(
131131
"sort",
132132
f"x[{x_idx}]",

0 commit comments

Comments
 (0)