Skip to content

Commit 8633155

Browse files
asmeurerhonno
authored andcommitted
Fix _test_stacks helper to correctly get the stacked axes for res
This logic may not be correctly general but it works for every function here so far (not sure yet if it will be correct for tensordot).
1 parent aae17fc commit 8633155

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

array_api_tests/test_linalg.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,17 @@
3737

3838
pytestmark = pytest.mark.ci
3939

40-
41-
4240
# Standin strategy for not yet implemented tests
4341
todo = none()
4442

45-
def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1),
43+
def _test_stacks(f, *args, res=None, dims=2, true_val=None,
44+
matrix_axes=(-2, -1),
4645
assert_equal=assert_exactly_equal, **kw):
4746
"""
4847
Test that f(*args, **kw) maps across stacks of matrices
4948
50-
dims is the number of dimensions f(*args) should have for a single n x m
51-
matrix stack.
49+
dims is the number of dimensions f(*args, *kw) should have for a single n
50+
x m matrix stack.
5251
5352
matrix_axes are the axes along which matrices (or vectors) are stacked in
5453
the input.
@@ -65,9 +64,13 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1)
6564

6665
shapes = [x.shape for x in args]
6766

67+
# Assume the result is stacked along the last 'dims' axes of matrix_axes.
68+
# This holds for all the functions tested in this file
69+
res_axes = matrix_axes[::-1][:dims]
70+
6871
for (x_idxes, (res_idx,)) in zip(
6972
iter_indices(*shapes, skip_axes=matrix_axes),
70-
iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))):
73+
iter_indices(res.shape, skip_axes=res_axes)):
7174
x_idxes = [x_idx.raw for x_idx in x_idxes]
7275
res_idx = res_idx.raw
7376

0 commit comments

Comments
 (0)