5
5
from hypothesis import strategies as st
6
6
from hypothesis .control import assume
7
7
8
- from .typing import Scalar , ScalarType , Shape
9
-
10
8
from . import _array_module as xp
11
9
from . import dtype_helpers as dh
12
10
from . import hypothesis_helpers as hh
13
11
from . import pytest_helpers as ph
14
12
from . import shape_helpers as sh
15
13
from . import xps
14
+ from .typing import Scalar , Shape
16
15
17
16
18
17
def assert_scalar_in_set (
19
18
func_name : str ,
20
- type_ : ScalarType ,
21
19
idx : Shape ,
22
20
out : Scalar ,
23
21
set_ : Set [Scalar ],
@@ -62,12 +60,13 @@ def test_argsort(x, data):
62
60
scalar_type = dh .get_scalar_type (x .dtype )
63
61
for indices in sh .axes_ndindex (x .shape , axes ):
64
62
elements = [scalar_type (x [idx ]) for idx in indices ]
65
- orders = sorted (range (len (elements )), key = elements .__getitem__ )
66
- if kw .get ("descending" , False ):
67
- orders = reversed (orders )
63
+ orders = list (range (len (elements )))
64
+ sorders = sorted (
65
+ orders , key = elements .__getitem__ , reverse = kw .get ("descending" , False )
66
+ )
68
67
if kw .get ("stable" , True ):
69
- for idx , o in zip (indices , orders ):
70
- ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o )
68
+ for idx , o in zip (indices , sorders ):
69
+ ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o , ** kw )
71
70
else :
72
71
idx_elements = dict (zip (indices , elements ))
73
72
idx_orders = dict (zip (indices , orders ))
@@ -76,17 +75,17 @@ def test_argsort(x, data):
76
75
element_orders [e ] = [
77
76
idx_orders [idx ] for idx in indices if idx_elements [idx ] == e
78
77
]
79
- for idx , e in zip ( indices , elements ):
80
- o = int ( out [ idx ])
78
+ selements = [ elements [ o ] for o in sorders ]
79
+ for idx , e in zip ( indices , selements ):
81
80
expected_orders = element_orders [e ]
81
+ out_o = int (out [idx ])
82
82
if len (expected_orders ) == 1 :
83
- expected_order = expected_orders [0 ]
84
83
ph .assert_scalar_equals (
85
- "argsort" , int , idx , o , expected_order , ** kw
84
+ "argsort" , int , idx , out_o , expected_orders [ 0 ] , ** kw
86
85
)
87
86
else :
88
87
assert_scalar_in_set (
89
- "argsort" , int , idx , o , set (expected_orders ), ** kw
88
+ "argsort" , idx , out_o , set (expected_orders ), ** kw
90
89
)
91
90
92
91
@@ -127,6 +126,7 @@ def test_sort(x, data):
127
126
)
128
127
for out_idx , o in zip (indices , orders ):
129
128
x_idx = indices [o ]
129
+ # TODO: error message when unstable should not imply just one idx
130
130
ph .assert_0d_equals (
131
131
"sort" ,
132
132
f"x[{ x_idx } ]" ,
0 commit comments