@@ -60,6 +60,7 @@ def assert_equal(x, y, msg_extra=None):
60
60
61
61
def _test_stacks (f , * args , res = None , dims = 2 , true_val = None ,
62
62
matrix_axes = (- 2 , - 1 ),
63
+ res_axes = None ,
63
64
assert_equal = assert_equal , ** kw ):
64
65
"""
65
66
Test that f(*args, **kw) maps across stacks of matrices
@@ -84,7 +85,10 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
84
85
85
86
# Assume the result is stacked along the last 'dims' axes of matrix_axes.
86
87
# This holds for all the functions tested in this file
87
- res_axes = matrix_axes [::- 1 ][:dims ]
88
+ if res_axes is None :
89
+ if not isinstance (matrix_axes , tuple ) and all (isinstance (x , int ) for x in matrix_axes ):
90
+ raise ValueError ("res_axes must be specified if matrix_axes is not a tuple of integers" )
91
+ res_axes = matrix_axes [::- 1 ][:dims ]
88
92
89
93
for (x_idxes , (res_idx ,)) in zip (
90
94
iter_indices (* shapes , skip_axes = matrix_axes ),
@@ -330,10 +334,12 @@ def test_matmul(x1, x2):
330
334
assert res .shape == ()
331
335
elif len (x1 .shape ) == 1 :
332
336
assert res .shape == x2 .shape [:- 2 ] + x2 .shape [- 1 :]
333
- _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 )
337
+ _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 ,
338
+ matrix_axes = [(0 ,), (- 2 , - 1 )], res_axes = [- 1 ])
334
339
elif len (x2 .shape ) == 1 :
335
340
assert res .shape == x1 .shape [:- 1 ]
336
- _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 )
341
+ _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 ,
342
+ matrix_axes = [(- 2 , - 1 ), (0 ,)], res_axes = [- 1 ])
337
343
else :
338
344
stack_shape = sh .broadcast_shapes (x1 .shape [:- 2 ], x2 .shape [:- 2 ])
339
345
assert res .shape == stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ])
@@ -546,10 +552,11 @@ def test_solve(x1, x2):
546
552
# TODO: This requires an upstream fix to ndindex
547
553
# (https://github.com/Quansight-Labs/ndindex/pull/131)
548
554
549
- # if x2.ndim == 1:
550
- # _test_stacks(linalg.solve, x1, x2, res=res, dims=1)
551
- # else:
552
- # _test_stacks(linalg.solve, x1, x2, res=res, dims=2)
555
+ if x2 .ndim == 1 :
556
+ _test_stacks (linalg .solve , x1 , x2 , res = res , dims = 1 ,
557
+ matrix_axes = [(- 2 , - 1 ), (0 ,)], res_axes = [- 1 ])
558
+ else :
559
+ _test_stacks (linalg .solve , x1 , x2 , res = res , dims = 2 )
553
560
554
561
@pytest .mark .xp_extension ('linalg' )
555
562
@given (
0 commit comments