5
5
from hypothesis import strategies as st
6
6
from hypothesis .control import assume
7
7
8
- from .typing import Scalar , ScalarType , Shape
9
-
10
8
from . import _array_module as xp
11
9
from . import dtype_helpers as dh
12
10
from . import hypothesis_helpers as hh
13
11
from . import pytest_helpers as ph
14
12
from . import shape_helpers as sh
15
13
from . import xps
14
+ from .typing import Scalar , Shape
16
15
17
16
18
17
def assert_scalar_in_set (
19
18
func_name : str ,
20
- type_ : ScalarType ,
21
19
idx : Shape ,
22
20
out : Scalar ,
23
21
set_ : Set [Scalar ],
@@ -62,13 +60,12 @@ def test_argsort(x, data):
62
60
scalar_type = dh .get_scalar_type (x .dtype )
63
61
for indices in sh .axes_ndindex (x .shape , axes ):
64
62
elements = [scalar_type (x [idx ]) for idx in indices ]
65
- orders = sorted (
66
- range (len (elements )),
67
- key = elements .__getitem__ ,
68
- reverse = kw .get ("descending" , False ),
63
+ orders = list (range (len (elements )))
64
+ sorders = sorted (
65
+ orders , key = elements .__getitem__ , reverse = kw .get ("descending" , False )
69
66
)
70
67
if kw .get ("stable" , True ):
71
- for idx , o in zip (indices , orders ):
68
+ for idx , o in zip (indices , sorders ):
72
69
ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o , ** kw )
73
70
else :
74
71
idx_elements = dict (zip (indices , elements ))
@@ -78,17 +75,17 @@ def test_argsort(x, data):
78
75
element_orders [e ] = [
79
76
idx_orders [idx ] for idx in indices if idx_elements [idx ] == e
80
77
]
81
- for idx , e in zip ( indices , elements ):
82
- o = int ( out [ idx ])
78
+ selements = [ elements [ o ] for o in sorders ]
79
+ for idx , e in zip ( indices , selements ):
83
80
expected_orders = element_orders [e ]
81
+ out_o = int (out [idx ])
84
82
if len (expected_orders ) == 1 :
85
- expected_order = expected_orders [0 ]
86
83
ph .assert_scalar_equals (
87
- "argsort" , int , idx , o , expected_order , ** kw
84
+ "argsort" , int , idx , out_o , expected_orders [ 0 ] , ** kw
88
85
)
89
86
else :
90
87
assert_scalar_in_set (
91
- "argsort" , int , idx , o , set (expected_orders ), ** kw
88
+ "argsort" , idx , out_o , set (expected_orders ), ** kw
92
89
)
93
90
94
91
@@ -129,6 +126,7 @@ def test_sort(x, data):
129
126
)
130
127
for out_idx , o in zip (indices , orders ):
131
128
x_idx = indices [o ]
129
+ # TODO: error message when unstable should not imply just one idx
132
130
ph .assert_0d_equals (
133
131
"sort" ,
134
132
f"x[{ x_idx } ]" ,
0 commit comments