Skip to content

Commit 0bdf4a5

Browse files
committed
Use correct normalised axis for test_concat
1 parent 27bebe6 commit 0bdf4a5

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ def test_concat(dtypes, kw, data):
6363
if axis is None:
6464
shape_strat = hh.shapes()
6565
else:
66-
_axis = axis if axis >= 0 else abs(axis) - 1
67-
shape_strat = shared_shapes(min_dims=_axis + 1).flatmap(
68-
lambda s: concat_shapes(s, axis)
66+
any_side_axis = axis if axis >= 0 else abs(axis) - 1
67+
shape_strat = shared_shapes(min_dims=any_side_axis + 1).flatmap(
68+
lambda s: concat_shapes(s, any_side_axis)
6969
)
7070
arrays = []
7171
for i, dtype in enumerate(dtypes, 1):
@@ -102,6 +102,8 @@ def test_concat(dtypes, kw, data):
102102
**kw,
103103
)
104104
else:
105+
ndim = len(shapes[0])
106+
_axis = axis if axis >= 0 else ndim - 1
105107
out_indices = sh.ndindex(out.shape)
106108
for idx in sh.axis_ndindex(shapes[0], _axis):
107109
f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)

0 commit comments

Comments
 (0)