42
42
43
43
pytestmark = pytest .mark .ci
44
44
45
- def assert_equal (x , y ):
45
+ def assert_equal (x , y , msg_extra = None ):
46
+ extra = '' if not msg_extra else f' ({ msg_extra } )'
46
47
if x .dtype in dh .float_dtypes :
47
48
# It's too difficult to do an approximately equal test here because
48
49
# different routines can give completely different answers, and even
@@ -51,10 +52,10 @@ def assert_equal(x, y):
51
52
52
53
# assert_allclose(x, y)
53
54
54
- assert x .shape == y .shape , f"The input arrays do not have the same shapes ({ x .shape } != { y .shape } )"
55
- assert x .dtype == y .dtype , f"The input arrays do not have the same dtype ({ x .dtype } != { y .dtype } )"
55
+ assert x .shape == y .shape , f"The input arrays do not have the same shapes ({ x .shape } != { y .shape } ){ extra } "
56
+ assert x .dtype == y .dtype , f"The input arrays do not have the same dtype ({ x .dtype } != { y .dtype } ){ extra } "
56
57
else :
57
- assert_exactly_equal (x , y )
58
+ assert_exactly_equal (x , y , msg_extra = msg_extra )
58
59
59
60
def _test_stacks (f , * args , res = None , dims = 2 , true_val = None ,
60
61
matrix_axes = (- 2 , - 1 ),
@@ -93,9 +94,10 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
93
94
res_stack = res [res_idx ]
94
95
x_stacks = [x [x_idx ] for x , x_idx in zip (args , x_idxes )]
95
96
decomp_res_stack = f (* x_stacks , ** kw )
96
- assert_equal (res_stack , decomp_res_stack )
97
+ msg_extra = f'{ x_idxes = } , { res_idx = } '
98
+ assert_equal (res_stack , decomp_res_stack , msg_extra )
97
99
if true_val :
98
- assert_equal (decomp_res_stack , true_val (* x_stacks ))
100
+ assert_equal (decomp_res_stack , true_val (* x_stacks ), msg_extra )
99
101
100
102
def _test_namedtuple (res , fields , func_name ):
101
103
"""
0 commit comments