Skip to content

Commit 72974e0

Browse files
committed
Allow passing an extra assertion message to assert_equal in linalg and assert_exactly_equal
1 parent 02542ff commit 72974e0

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

array_api_tests/array_helpers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,19 +205,21 @@ def notequal(x, y):
205205

206206
return not_equal(x, y)
207207

208-
def assert_exactly_equal(x, y):
208+
def assert_exactly_equal(x, y, msg_extra=None):
209209
"""
210210
Test that the arrays x and y are exactly equal.
211211
212212
If x and y do not have the same shape and dtype, they are not considered
213213
equal.
214214
215215
"""
216-
assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})"
216+
extra = '' if not msg_extra else f' ({msg_extra})'
217217

218-
assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})"
218+
assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}"
219+
220+
assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}"
219221

220-
assert all(exactly_equal(x, y)), f"The input arrays have different values ({x!r} != {y!r})"
222+
assert all(exactly_equal(x, y)), f"The input arrays have different values ({x!r} != {y!r}){extra}"
221223

222224
def assert_finite(x):
223225
"""

array_api_tests/test_linalg.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242

4343
pytestmark = pytest.mark.ci
4444

45-
def assert_equal(x, y):
45+
def assert_equal(x, y, msg_extra=None):
46+
extra = '' if not msg_extra else f' ({msg_extra})'
4647
if x.dtype in dh.float_dtypes:
4748
# It's too difficult to do an approximately equal test here because
4849
# different routines can give completely different answers, and even
@@ -51,10 +52,10 @@ def assert_equal(x, y):
5152

5253
# assert_allclose(x, y)
5354

54-
assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})"
55-
assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})"
55+
assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}"
56+
assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}"
5657
else:
57-
assert_exactly_equal(x, y)
58+
assert_exactly_equal(x, y, msg_extra=msg_extra)
5859

5960
def _test_stacks(f, *args, res=None, dims=2, true_val=None,
6061
matrix_axes=(-2, -1),
@@ -93,9 +94,10 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
9394
res_stack = res[res_idx]
9495
x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)]
9596
decomp_res_stack = f(*x_stacks, **kw)
96-
assert_equal(res_stack, decomp_res_stack)
97+
msg_extra = f'{x_idxes = }, {res_idx = }'
98+
assert_equal(res_stack, decomp_res_stack, msg_extra)
9799
if true_val:
98-
assert_equal(decomp_res_stack, true_val(*x_stacks))
100+
assert_equal(decomp_res_stack, true_val(*x_stacks), msg_extra)
99101

100102
def _test_namedtuple(res, fields, func_name):
101103
"""

0 commit comments

Comments
 (0)