Skip to content

Commit 5c1aa45

Browse files
committed
Update _test_stacks to use updated ndindex behavior
This requires Quansight-Labs/ndindex#155 which is not yet released.
1 parent 3501116 commit 5c1aa45

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

array_api_tests/test_linalg.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def assert_equal(x, y, msg_extra=None):
6060

6161
def _test_stacks(f, *args, res=None, dims=2, true_val=None,
6262
matrix_axes=(-2, -1),
63+
res_axes=None,
6364
assert_equal=assert_equal, **kw):
6465
"""
6566
Test that f(*args, **kw) maps across stacks of matrices
@@ -84,7 +85,10 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
8485

8586
# Assume the result is stacked along the last 'dims' axes of matrix_axes.
8687
# This holds for all the functions tested in this file
87-
res_axes = matrix_axes[::-1][:dims]
88+
if res_axes is None:
89+
if not isinstance(matrix_axes, tuple) and all(isinstance(x, int) for x in matrix_axes):
90+
raise ValueError("res_axes must be specified if matrix_axes is not a tuple of integers")
91+
res_axes = matrix_axes[::-1][:dims]
8892

8993
for (x_idxes, (res_idx,)) in zip(
9094
iter_indices(*shapes, skip_axes=matrix_axes),
@@ -330,10 +334,12 @@ def test_matmul(x1, x2):
330334
assert res.shape == ()
331335
elif len(x1.shape) == 1:
332336
assert res.shape == x2.shape[:-2] + x2.shape[-1:]
333-
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1)
337+
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
338+
matrix_axes=[(0,), (-2, -1)], res_axes=[-1])
334339
elif len(x2.shape) == 1:
335340
assert res.shape == x1.shape[:-1]
336-
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1)
341+
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
342+
matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
337343
else:
338344
stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
339345
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
@@ -546,10 +552,11 @@ def test_solve(x1, x2):
546552
# TODO: This requires an upstream fix to ndindex
547553
# (https://github.com/Quansight-Labs/ndindex/pull/131)
548554

549-
# if x2.ndim == 1:
550-
# _test_stacks(linalg.solve, x1, x2, res=res, dims=1)
551-
# else:
552-
# _test_stacks(linalg.solve, x1, x2, res=res, dims=2)
555+
if x2.ndim == 1:
556+
_test_stacks(linalg.solve, x1, x2, res=res, dims=1,
557+
matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
558+
else:
559+
_test_stacks(linalg.solve, x1, x2, res=res, dims=2)
553560

554561
@pytest.mark.xp_extension('linalg')
555562
@given(

0 commit comments

Comments
 (0)