Skip to content

Commit bcce8e6

Browse files
authored
Merge pull request #72 from honno/concat
`test_concat` fixes
2 parents 27bebe6 + 2b70419 commit bcce8e6

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,27 @@ def assert_array_ndindex(
4646
assert out[out_idx] == x[x_idx], msg
4747

4848

49-
@st.composite
50-
def concat_shapes(draw, shape, axis):
51-
shape = list(shape)
52-
shape[axis] = draw(st.integers(1, MAX_SIDE))
53-
return tuple(shape)
54-
55-
5649
@given(
5750
dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
58-
kw=hh.kwargs(axis=st.none() | st.integers(-MAX_DIMS, MAX_DIMS - 1)),
51+
base_shape=hh.shapes(),
5952
data=st.data(),
6053
)
61-
def test_concat(dtypes, kw, data):
54+
def test_concat(dtypes, base_shape, data):
55+
axis_strat = st.none()
56+
ndim = len(base_shape)
57+
if ndim > 0:
58+
axis_strat |= st.integers(-ndim, ndim - 1)
59+
kw = data.draw(
60+
axis_strat.flatmap(lambda a: hh.specified_kwargs(("axis", a, 0))), label="kw"
61+
)
6262
axis = kw.get("axis", 0)
6363
if axis is None:
64+
_axis = None
6465
shape_strat = hh.shapes()
6566
else:
66-
_axis = axis if axis >= 0 else abs(axis) - 1
67-
shape_strat = shared_shapes(min_dims=_axis + 1).flatmap(
68-
lambda s: concat_shapes(s, axis)
67+
_axis = axis if axis >= 0 else len(base_shape) + axis
68+
shape_strat = st.integers(0, MAX_SIDE).map(
69+
lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
6970
)
7071
arrays = []
7172
for i, dtype in enumerate(dtypes, 1):
@@ -77,18 +78,17 @@ def test_concat(dtypes, kw, data):
7778
ph.assert_dtype("concat", dtypes, out.dtype)
7879

7980
shapes = tuple(x.shape for x in arrays)
80-
axis = kw.get("axis", 0)
81-
if axis is None:
81+
if _axis is None:
8282
size = sum(math.prod(s) for s in shapes)
8383
shape = (size,)
8484
else:
8585
shape = list(shapes[0])
8686
for other_shape in shapes[1:]:
87-
shape[axis] += other_shape[axis]
87+
shape[_axis] += other_shape[_axis]
8888
shape = tuple(shape)
8989
ph.assert_result_shape("concat", shapes, out.shape, shape, **kw)
9090

91-
if axis is None:
91+
if _axis is None:
9292
out_indices = (i for i in range(out.size))
9393
for x_num, x in enumerate(arrays, 1):
9494
for x_idx in sh.ndindex(x.shape):
@@ -291,7 +291,7 @@ def test_roll(x, data):
291291
else:
292292
axis_strat = st.none()
293293
if x.ndim != 0:
294-
axis_strat = axis_strat | st.integers(-x.ndim, x.ndim - 1)
294+
axis_strat |= st.integers(-x.ndim, x.ndim - 1)
295295
kw_strat = hh.kwargs(axis=axis_strat)
296296
kw = data.draw(kw_strat, label="kw")
297297

0 commit comments

Comments
 (0)