@@ -46,49 +46,50 @@ def assert_array_ndindex(
46
46
assert out [out_idx ] == x [x_idx ], msg
47
47
48
48
49
- @st .composite
50
- def concat_shapes (draw , shape , axis ):
51
- shape = list (shape )
52
- shape [axis ] = draw (st .integers (1 , MAX_SIDE ))
53
- return tuple (shape )
54
-
55
-
56
49
@given (
57
50
dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
58
- kw = hh . kwargs ( axis = st .none () | st .integers (- MAX_DIMS , MAX_DIMS - 1 ) ),
51
+ _axis = st .none () | st .integers (0 , MAX_DIMS - 1 ),
59
52
data = st .data (),
60
53
)
61
- def test_concat (dtypes , kw , data ):
62
- axis = kw .get ("axis" , 0 )
63
- if axis is None :
54
+ def test_concat (dtypes , _axis , data ):
55
+ if _axis is None :
64
56
shape_strat = hh .shapes ()
57
+ axis_strat = st .none ()
65
58
else :
66
- any_side_axis = axis if axis >= 0 else abs (axis ) - 1
67
- shape_strat = shared_shapes (min_dims = any_side_axis + 1 ).flatmap (
68
- lambda s : concat_shapes (s , any_side_axis )
59
+ base_shape = data .draw (
60
+ hh .shapes (min_dims = _axis + 1 ).map (
61
+ lambda t : t [:_axis ] + (None ,) + t [_axis + 1 :]
62
+ ),
63
+ label = "base shape" ,
64
+ )
65
+ shape_strat = st .integers (0 , MAX_SIDE ).map (
66
+ lambda i : base_shape [:_axis ] + (i ,) + base_shape [_axis + 1 :]
69
67
)
68
+ axis_strat = st .sampled_from ([_axis , _axis - len (base_shape )])
70
69
arrays = []
71
70
for i , dtype in enumerate (dtypes , 1 ):
72
71
x = data .draw (xps .arrays (dtype = dtype , shape = shape_strat ), label = f"x{ i } " )
73
72
arrays .append (x )
73
+ kw = data .draw (
74
+ axis_strat .flatmap (lambda a : hh .specified_kwargs (("axis" , a , 0 ))), label = "kw"
75
+ )
74
76
75
77
out = xp .concat (arrays , ** kw )
76
78
77
79
ph .assert_dtype ("concat" , dtypes , out .dtype )
78
80
79
81
shapes = tuple (x .shape for x in arrays )
80
- axis = kw .get ("axis" , 0 )
81
- if axis is None :
82
+ if _axis is None :
82
83
size = sum (math .prod (s ) for s in shapes )
83
84
shape = (size ,)
84
85
else :
85
86
shape = list (shapes [0 ])
86
87
for other_shape in shapes [1 :]:
87
- shape [axis ] += other_shape [axis ]
88
+ shape [_axis ] += other_shape [_axis ]
88
89
shape = tuple (shape )
89
90
ph .assert_result_shape ("concat" , shapes , out .shape , shape , ** kw )
90
91
91
- if axis is None :
92
+ if _axis is None :
92
93
out_indices = (i for i in range (out .size ))
93
94
for x_num , x in enumerate (arrays , 1 ):
94
95
for x_idx in sh .ndindex (x .shape ):
@@ -102,8 +103,6 @@ def test_concat(dtypes, kw, data):
102
103
** kw ,
103
104
)
104
105
else :
105
- ndim = len (shapes [0 ])
106
- _axis = axis if axis >= 0 else ndim - 1
107
106
out_indices = sh .ndindex (out .shape )
108
107
for idx in sh .axis_ndindex (shapes [0 ], _axis ):
109
108
f_idx = ", " .join (str (i ) if isinstance (i , int ) else ":" for i in idx )
0 commit comments