Skip to content

Commit 1140a1a

Browse files
committed
WIP: xp_assert enhancements
1 parent 6e3ad0f commit 1140a1a

File tree

2 files changed

+122
-61
lines changed

2 files changed

+122
-61
lines changed

src/array_api_extra/_lib/_testing.py

Lines changed: 71 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
from types import ModuleType
1010
from typing import cast
1111

12+
import numpy as np
1213
import pytest
1314

1415
from ._utils._compat import (
1516
array_namespace,
1617
is_array_api_strict_namespace,
1718
is_cupy_namespace,
1819
is_dask_namespace,
20+
is_numpy_namespace,
1921
is_pydata_sparse_namespace,
2022
is_torch_namespace,
2123
)
@@ -25,7 +27,11 @@
2527

2628

2729
def _check_ns_shape_dtype(
28-
actual: Array, desired: Array
30+
actual: Array,
31+
desired: Array,
32+
check_dtype: bool,
33+
check_shape: bool,
34+
check_scalar: bool,
2935
) -> ModuleType: # numpydoc ignore=RT03
3036
"""
3137
Assert that namespace, shape and dtype of the two arrays match.
@@ -47,43 +53,64 @@ def _check_ns_shape_dtype(
4753
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
4854
assert actual_xp == desired_xp, msg
4955

50-
actual_shape = actual.shape
51-
desired_shape = desired.shape
52-
if is_dask_namespace(desired_xp):
53-
# Dask uses nan instead of None for unknown shapes
54-
if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
55-
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
56-
if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
57-
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
58-
59-
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
60-
assert actual_shape == desired_shape, msg
61-
62-
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
63-
assert actual.dtype == desired.dtype, msg
56+
if check_shape:
57+
actual_shape = actual.shape
58+
desired_shape = desired.shape
59+
if is_dask_namespace(desired_xp):
60+
# Dask uses nan instead of None for unknown shapes
61+
if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
62+
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
63+
if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
64+
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
65+
66+
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
67+
assert actual_shape == desired_shape, msg
68+
69+
if check_dtype:
70+
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
71+
assert actual.dtype == desired.dtype, msg
72+
73+
if is_numpy_namespace(actual_xp) and check_scalar:
74+
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
75+
_msg = (
76+
"array-ness does not match:\n Actual: "
77+
f"{type(actual)}\n Desired: {type(desired)}"
78+
)
79+
assert (np.isscalar(actual) and np.isscalar(desired)) or (
80+
not np.isscalar(actual) and not np.isscalar(desired)
81+
), _msg
6482

6583
return desired_xp
6684

6785

6886
def _prepare_for_test(array: Array, xp: ModuleType) -> Array:
6987
"""
70-
Ensure that the array can be compared with xp.testing or np.testing.
88+
Ensure that the array can be compared with np.testing.
7189
7290
This involves transferring it from GPU to CPU memory, densifying it, etc.
7391
"""
7492
if is_torch_namespace(xp):
75-
return array.cpu() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
93+
return np.asarray(array.cpu()) # type: ignore[attr-defined, return-value] # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportReturnType]
7694
if is_pydata_sparse_namespace(xp):
7795
return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
7896
if is_array_api_strict_namespace(xp):
7997
# Note: we deliberately did not add a `.to_device` method in _typing.pyi
8098
# even if it is required by the standard as many backends don't support it
8199
return array.to_device(xp.Device("CPU_DEVICE")) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82-
# Note: nothing to do for CuPy, because it uses a bespoke test function
100+
if is_cupy_namespace(xp):
101+
return xp.asnumpy(array)
83102
return array
84103

85104

86-
def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
105+
def xp_assert_equal(
106+
actual: Array,
107+
desired: Array,
108+
*,
109+
err_msg: str = "",
110+
check_dtype: bool = True,
111+
check_shape: bool = True,
112+
check_scalar: bool = False,
113+
) -> None:
87114
"""
88115
Array-API compatible version of `np.testing.assert_array_equal`.
89116
@@ -95,34 +122,21 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
95122
The expected array (typically hardcoded).
96123
err_msg : str, optional
97124
Error message to display on failure.
125+
check_dtype, check_shape : bool, default: True
126+
Whether to check agreement between actual and desired dtypes and shapes
127+
check_scalar : bool, default: False
128+
NumPy only: whether to check agreement between actual and desired types -
129+
0d array vs scalar.
98130
99131
See Also
100132
--------
101133
xp_assert_close : Similar function for inexact equality checks.
102134
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
103135
"""
104-
xp = _check_ns_shape_dtype(actual, desired)
136+
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
105137
actual = _prepare_for_test(actual, xp)
106138
desired = _prepare_for_test(desired, xp)
107-
108-
if is_cupy_namespace(xp):
109-
xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
110-
elif is_torch_namespace(xp):
111-
# PyTorch recommends using `rtol=0, atol=0` like this
112-
# to test for exact equality
113-
xp.testing.assert_close(
114-
actual,
115-
desired,
116-
rtol=0,
117-
atol=0,
118-
equal_nan=True,
119-
check_dtype=False,
120-
msg=err_msg or None,
121-
)
122-
else:
123-
import numpy as np # pylint: disable=import-outside-toplevel
124-
125-
np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
139+
np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
126140

127141

128142
def xp_assert_close(
@@ -132,6 +146,9 @@ def xp_assert_close(
132146
rtol: float | None = None,
133147
atol: float = 0,
134148
err_msg: str = "",
149+
check_dtype: bool = True,
150+
check_shape: bool = True,
151+
check_scalar: bool = False,
135152
) -> None:
136153
"""
137154
Array-API compatible version of `np.testing.assert_allclose`.
@@ -148,6 +165,11 @@ def xp_assert_close(
148165
Absolute tolerance. Default: 0.
149166
err_msg : str, optional
150167
Error message to display on failure.
168+
check_dtype, check_shape : bool, default: True
169+
Whether to check agreement between actual and desired dtypes and shapes
170+
check_scalar : bool, default: False
171+
NumPy only: whether to check agreement between actual and desired types -
172+
0d array vs scalar.
151173
152174
See Also
153175
--------
@@ -159,7 +181,7 @@ def xp_assert_close(
159181
-----
160182
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
161183
"""
162-
xp = _check_ns_shape_dtype(actual, desired)
184+
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
163185

164186
floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
165187
if rtol is None and floating:
@@ -173,26 +195,15 @@ def xp_assert_close(
173195
actual = _prepare_for_test(actual, xp)
174196
desired = _prepare_for_test(desired, xp)
175197

176-
if is_cupy_namespace(xp):
177-
xp.testing.assert_allclose(
178-
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
179-
)
180-
elif is_torch_namespace(xp):
181-
xp.testing.assert_close(
182-
actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None
183-
)
184-
else:
185-
import numpy as np # pylint: disable=import-outside-toplevel
186-
187-
# JAX/Dask arrays work directly with `np.testing`
188-
assert isinstance(rtol, float)
189-
np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
190-
actual, # pyright: ignore[reportArgumentType]
191-
desired, # pyright: ignore[reportArgumentType]
192-
rtol=rtol,
193-
atol=atol,
194-
err_msg=err_msg,
195-
)
198+
# JAX/Dask arrays work directly with `np.testing`
199+
assert isinstance(rtol, float)
200+
np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
201+
actual, # pyright: ignore[reportArgumentType]
202+
desired, # pyright: ignore[reportArgumentType]
203+
rtol=rtol,
204+
atol=atol,
205+
err_msg=err_msg,
206+
)
196207

197208

198209
def xfail(request: pytest.FixtureRequest, reason: str) -> None:

tests/test_testing.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Callable
2+
from contextlib import nullcontext
23
from types import ModuleType
34
from typing import cast
45

@@ -24,7 +25,9 @@
2425
xp_assert_equal,
2526
pytest.param(
2627
xp_assert_close,
27-
marks=pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype"),
28+
marks=pytest.mark.xfail_xp_backend(
29+
Backend.SPARSE, reason="no isdtype", strict=False
30+
),
2831
),
2932
],
3033
)
@@ -60,6 +63,53 @@ def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None])
6063
func(xp.asarray([0]), [0])
6164

6265

66+
@param_assert_equal_close
67+
@pytest.mark.parametrize("check_shape", [False, True])
68+
def test_assert_close_equal_shape( # type: ignore[explicit-any]
69+
xp: ModuleType,
70+
func: Callable[..., None],
71+
check_shape: bool,
72+
):
73+
context = (
74+
pytest.raises(AssertionError, match="shapes do not match")
75+
if check_shape
76+
else nullcontext()
77+
)
78+
with context:
79+
func(xp.asarray([0, 0]), xp.asarray(0), check_shape=check_shape)
80+
81+
82+
@param_assert_equal_close
83+
@pytest.mark.parametrize("check_dtype", [False, True])
84+
def test_assert_close_equal_dtype( # type: ignore[explicit-any]
85+
xp: ModuleType,
86+
func: Callable[..., None],
87+
check_dtype: bool,
88+
):
89+
context = (
90+
pytest.raises(AssertionError, match="dtypes do not match")
91+
if check_dtype
92+
else nullcontext()
93+
)
94+
with context:
95+
func(xp.asarray(0.0), xp.asarray(0), check_dtype=check_dtype)
96+
97+
98+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
99+
@pytest.mark.parametrize("check_scalar", [False, True])
100+
def test_assert_close_equal_scalar( # type: ignore[explicit-any]
101+
func: Callable[..., None],
102+
check_scalar: bool,
103+
):
104+
context = (
105+
pytest.raises(AssertionError, match="array-ness does not match")
106+
if check_scalar
107+
else nullcontext()
108+
)
109+
with context:
110+
func(np.asarray(0), np.asarray(0)[()], check_scalar=check_scalar)
111+
112+
63113
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
64114
def test_assert_close_tolerance(xp: ModuleType):
65115
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03)

0 commit comments

Comments
 (0)