Skip to content

Commit 8f8da4a

Browse files
committed
Generate normalised axis first for test_concat
1 parent 0bdf4a5 commit 8f8da4a

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,49 +46,50 @@ def assert_array_ndindex(
4646
assert out[out_idx] == x[x_idx], msg
4747

4848

49-
@st.composite
50-
def concat_shapes(draw, shape, axis):
51-
shape = list(shape)
52-
shape[axis] = draw(st.integers(1, MAX_SIDE))
53-
return tuple(shape)
54-
55-
5649
@given(
5750
dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
58-
kw=hh.kwargs(axis=st.none() | st.integers(-MAX_DIMS, MAX_DIMS - 1)),
51+
_axis=st.none() | st.integers(0, MAX_DIMS - 1),
5952
data=st.data(),
6053
)
61-
def test_concat(dtypes, kw, data):
62-
axis = kw.get("axis", 0)
63-
if axis is None:
54+
def test_concat(dtypes, _axis, data):
55+
if _axis is None:
6456
shape_strat = hh.shapes()
57+
axis_strat = st.none()
6558
else:
66-
any_side_axis = axis if axis >= 0 else abs(axis) - 1
67-
shape_strat = shared_shapes(min_dims=any_side_axis + 1).flatmap(
68-
lambda s: concat_shapes(s, any_side_axis)
59+
base_shape = data.draw(
60+
hh.shapes(min_dims=_axis + 1).map(
61+
lambda t: t[:_axis] + (None,) + t[_axis + 1 :]
62+
),
63+
label="base shape",
64+
)
65+
shape_strat = st.integers(0, MAX_SIDE).map(
66+
lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
6967
)
68+
axis_strat = st.sampled_from([_axis, _axis - len(base_shape)])
7069
arrays = []
7170
for i, dtype in enumerate(dtypes, 1):
7271
x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}")
7372
arrays.append(x)
73+
kw = data.draw(
74+
axis_strat.flatmap(lambda a: hh.specified_kwargs(("axis", a, 0))), label="kw"
75+
)
7476

7577
out = xp.concat(arrays, **kw)
7678

7779
ph.assert_dtype("concat", dtypes, out.dtype)
7880

7981
shapes = tuple(x.shape for x in arrays)
80-
axis = kw.get("axis", 0)
81-
if axis is None:
82+
if _axis is None:
8283
size = sum(math.prod(s) for s in shapes)
8384
shape = (size,)
8485
else:
8586
shape = list(shapes[0])
8687
for other_shape in shapes[1:]:
87-
shape[axis] += other_shape[axis]
88+
shape[_axis] += other_shape[_axis]
8889
shape = tuple(shape)
8990
ph.assert_result_shape("concat", shapes, out.shape, shape, **kw)
9091

91-
if axis is None:
92+
if _axis is None:
9293
out_indices = (i for i in range(out.size))
9394
for x_num, x in enumerate(arrays, 1):
9495
for x_idx in sh.ndindex(x.shape):
@@ -102,8 +103,6 @@ def test_concat(dtypes, kw, data):
102103
**kw,
103104
)
104105
else:
105-
ndim = len(shapes[0])
106-
_axis = axis if axis >= 0 else ndim - 1
107106
out_indices = sh.ndindex(out.shape)
108107
for idx in sh.axis_ndindex(shapes[0], _axis):
109108
f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)

0 commit comments

Comments
 (0)