9
9
from types import ModuleType
10
10
from typing import cast
11
11
12
+ import numpy as np
12
13
import pytest
13
14
14
15
from ._utils ._compat import (
15
16
array_namespace ,
16
17
is_array_api_strict_namespace ,
17
18
is_cupy_namespace ,
18
19
is_dask_namespace ,
20
+ is_numpy_namespace ,
19
21
is_pydata_sparse_namespace ,
20
22
is_torch_namespace ,
21
23
)
25
27
26
28
27
29
def _check_ns_shape_dtype (
28
- actual : Array , desired : Array
30
+ actual : Array ,
31
+ desired : Array ,
32
+ check_dtype : bool ,
33
+ check_shape : bool ,
34
+ check_scalar : bool ,
29
35
) -> ModuleType : # numpydoc ignore=RT03
30
36
"""
31
37
Assert that namespace, shape and dtype of the two arrays match.
@@ -47,43 +53,64 @@ def _check_ns_shape_dtype(
47
53
msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
48
54
assert actual_xp == desired_xp , msg
49
55
50
- actual_shape = actual .shape
51
- desired_shape = desired .shape
52
- if is_dask_namespace (desired_xp ):
53
- # Dask uses nan instead of None for unknown shapes
54
- if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
55
- actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
56
- if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
57
- desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
58
-
59
- msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
60
- assert actual_shape == desired_shape , msg
61
-
62
- msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
63
- assert actual .dtype == desired .dtype , msg
56
+ if check_shape :
57
+ actual_shape = actual .shape
58
+ desired_shape = desired .shape
59
+ if is_dask_namespace (desired_xp ):
60
+ # Dask uses nan instead of None for unknown shapes
61
+ if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
62
+ actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
63
+ if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
64
+ desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
65
+
66
+ msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
67
+ assert actual_shape == desired_shape , msg
68
+
69
+ if check_dtype :
70
+ msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
71
+ assert actual .dtype == desired .dtype , msg
72
+
73
+ if is_numpy_namespace (actual_xp ) and check_scalar :
74
+ # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
75
+ _msg = (
76
+ "array-ness does not match:\n Actual: "
77
+ f"{ type (actual )} \n Desired: { type (desired )} "
78
+ )
79
+ assert (np .isscalar (actual ) and np .isscalar (desired )) or (
80
+ not np .isscalar (actual ) and not np .isscalar (desired )
81
+ ), _msg
64
82
65
83
return desired_xp
66
84
67
85
68
86
def _prepare_for_test (array : Array , xp : ModuleType ) -> Array :
69
87
"""
70
- Ensure that the array can be compared with xp.testing or np.testing.
88
+ Ensure that the array can be compared with np.testing.
71
89
72
90
This involves transferring it from GPU to CPU memory, densifying it, etc.
73
91
"""
74
92
if is_torch_namespace (xp ):
75
- return array .cpu () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
93
+ return np . asarray ( array .cpu ()) # type: ignore[attr-defined, return-value ] # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportReturnType ]
76
94
if is_pydata_sparse_namespace (xp ):
77
95
return array .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
78
96
if is_array_api_strict_namespace (xp ):
79
97
# Note: we deliberately did not add a `.to_device` method in _typing.pyi
80
98
# even if it is required by the standard as many backends don't support it
81
99
return array .to_device (xp .Device ("CPU_DEVICE" )) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82
- # Note: nothing to do for CuPy, because it uses a bespoke test function
100
+ if is_cupy_namespace (xp ):
101
+ return xp .asnumpy (array )
83
102
return array
84
103
85
104
86
- def xp_assert_equal (actual : Array , desired : Array , err_msg : str = "" ) -> None :
105
+ def xp_assert_equal (
106
+ actual : Array ,
107
+ desired : Array ,
108
+ * ,
109
+ err_msg : str = "" ,
110
+ check_dtype : bool = True ,
111
+ check_shape : bool = True ,
112
+ check_scalar : bool = False ,
113
+ ) -> None :
87
114
"""
88
115
Array-API compatible version of `np.testing.assert_array_equal`.
89
116
@@ -95,34 +122,21 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
95
122
The expected array (typically hardcoded).
96
123
err_msg : str, optional
97
124
Error message to display on failure.
125
+ check_dtype, check_shape : bool, default: True
126
+ Whether to check agreement between actual and desired dtypes and shapes
127
+ check_scalar : bool, default: False
128
+ NumPy only: whether to check agreement between actual and desired types -
129
+ 0d array vs scalar.
98
130
99
131
See Also
100
132
--------
101
133
xp_assert_close : Similar function for inexact equality checks.
102
134
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
103
135
"""
104
- xp = _check_ns_shape_dtype (actual , desired )
136
+ xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
105
137
actual = _prepare_for_test (actual , xp )
106
138
desired = _prepare_for_test (desired , xp )
107
-
108
- if is_cupy_namespace (xp ):
109
- xp .testing .assert_array_equal (actual , desired , err_msg = err_msg )
110
- elif is_torch_namespace (xp ):
111
- # PyTorch recommends using `rtol=0, atol=0` like this
112
- # to test for exact equality
113
- xp .testing .assert_close (
114
- actual ,
115
- desired ,
116
- rtol = 0 ,
117
- atol = 0 ,
118
- equal_nan = True ,
119
- check_dtype = False ,
120
- msg = err_msg or None ,
121
- )
122
- else :
123
- import numpy as np # pylint: disable=import-outside-toplevel
124
-
125
- np .testing .assert_array_equal (actual , desired , err_msg = err_msg )
139
+ np .testing .assert_array_equal (actual , desired , err_msg = err_msg )
126
140
127
141
128
142
def xp_assert_close (
@@ -132,6 +146,9 @@ def xp_assert_close(
132
146
rtol : float | None = None ,
133
147
atol : float = 0 ,
134
148
err_msg : str = "" ,
149
+ check_dtype : bool = True ,
150
+ check_shape : bool = True ,
151
+ check_scalar : bool = False ,
135
152
) -> None :
136
153
"""
137
154
Array-API compatible version of `np.testing.assert_allclose`.
@@ -148,6 +165,11 @@ def xp_assert_close(
148
165
Absolute tolerance. Default: 0.
149
166
err_msg : str, optional
150
167
Error message to display on failure.
168
+ check_dtype, check_shape : bool, default: True
169
+ Whether to check agreement between actual and desired dtypes and shapes
170
+ check_scalar : bool, default: False
171
+ NumPy only: whether to check agreement between actual and desired types -
172
+ 0d array vs scalar.
151
173
152
174
See Also
153
175
--------
@@ -159,7 +181,7 @@ def xp_assert_close(
159
181
-----
160
182
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
161
183
"""
162
- xp = _check_ns_shape_dtype (actual , desired )
184
+ xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
163
185
164
186
floating = xp .isdtype (actual .dtype , ("real floating" , "complex floating" ))
165
187
if rtol is None and floating :
@@ -173,26 +195,15 @@ def xp_assert_close(
173
195
actual = _prepare_for_test (actual , xp )
174
196
desired = _prepare_for_test (desired , xp )
175
197
176
- if is_cupy_namespace (xp ):
177
- xp .testing .assert_allclose (
178
- actual , desired , rtol = rtol , atol = atol , err_msg = err_msg
179
- )
180
- elif is_torch_namespace (xp ):
181
- xp .testing .assert_close (
182
- actual , desired , rtol = rtol , atol = atol , equal_nan = True , msg = err_msg or None
183
- )
184
- else :
185
- import numpy as np # pylint: disable=import-outside-toplevel
186
-
187
- # JAX/Dask arrays work directly with `np.testing`
188
- assert isinstance (rtol , float )
189
- np .testing .assert_allclose ( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
190
- actual , # pyright: ignore[reportArgumentType]
191
- desired , # pyright: ignore[reportArgumentType]
192
- rtol = rtol ,
193
- atol = atol ,
194
- err_msg = err_msg ,
195
- )
198
+ # JAX/Dask arrays work directly with `np.testing`
199
+ assert isinstance (rtol , float )
200
+ np .testing .assert_allclose ( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
201
+ actual , # pyright: ignore[reportArgumentType]
202
+ desired , # pyright: ignore[reportArgumentType]
203
+ rtol = rtol ,
204
+ atol = atol ,
205
+ err_msg = err_msg ,
206
+ )
196
207
197
208
198
209
def xfail (request : pytest .FixtureRequest , reason : str ) -> None :
0 commit comments