37
37
38
38
pytestmark = pytest .mark .ci
39
39
40
-
41
-
42
40
# Standin strategy for not yet implemented tests
43
41
todo = none ()
44
42
45
- def _test_stacks (f , * args , res = None , dims = 2 , true_val = None , matrix_axes = (- 2 , - 1 ),
43
+ def _test_stacks (f , * args , res = None , dims = 2 , true_val = None ,
44
+ matrix_axes = (- 2 , - 1 ),
46
45
assert_equal = assert_exactly_equal , ** kw ):
47
46
"""
48
47
Test that f(*args, **kw) maps across stacks of matrices
49
48
50
- dims is the number of dimensions f(*args) should have for a single n x m
51
- matrix stack.
49
+ dims is the number of dimensions f(*args, *kw ) should have for a single n
50
+ x m matrix stack.
52
51
53
52
matrix_axes are the axes along which matrices (or vectors) are stacked in
54
53
the input.
@@ -65,9 +64,13 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1)
65
64
66
65
shapes = [x .shape for x in args ]
67
66
67
+ # Assume the result is stacked along the last 'dims' axes of matrix_axes.
68
+ # This holds for all the functions tested in this file
69
+ res_axes = matrix_axes [::- 1 ][:dims ]
70
+
68
71
for (x_idxes , (res_idx ,)) in zip (
69
72
iter_indices (* shapes , skip_axes = matrix_axes ),
70
- iter_indices (res .shape , skip_axes = tuple ( range ( - dims , 0 )) )):
73
+ iter_indices (res .shape , skip_axes = res_axes )):
71
74
x_idxes = [x_idx .raw for x_idx in x_idxes ]
72
75
res_idx = res_idx .raw
73
76
0 commit comments