Skip to content

Rework prepare_for_test #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 132 additions & 84 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,33 @@

import math
from types import ModuleType
from typing import cast
from typing import Any, cast

import numpy as np
import pytest

from ._utils._compat import (
array_namespace,
is_array_api_strict_namespace,
is_cupy_namespace,
is_dask_namespace,
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
to_device,
)
from ._utils._typing import Array
from ._utils._typing import Array, Device

__all__ = ["xp_assert_close", "xp_assert_equal"]
__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"]


def _check_ns_shape_dtype(
actual: Array, desired: Array
actual: Array,
desired: Array,
check_dtype: bool,
check_shape: bool,
check_scalar: bool,
) -> ModuleType: # numpydoc ignore=RT03
"""
Assert that namespace, shape and dtype of the two arrays match.
Expand All @@ -47,43 +55,67 @@ def _check_ns_shape_dtype(
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
assert actual_xp == desired_xp, msg

actual_shape = actual.shape
desired_shape = desired.shape
if is_dask_namespace(desired_xp):
# Dask uses nan instead of None for unknown shapes
if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]

msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
assert actual_shape == desired_shape, msg

msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
assert actual.dtype == desired.dtype, msg
if check_shape:
actual_shape = actual.shape
desired_shape = desired.shape
if is_dask_namespace(desired_xp):
# Dask uses nan instead of None for unknown shapes
if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]

msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
assert actual_shape == desired_shape, msg

if check_dtype:
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
assert actual.dtype == desired.dtype, msg

if is_numpy_namespace(actual_xp) and check_scalar:
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
_msg = (
"array-ness does not match:\n Actual: "
f"{type(actual)}\n Desired: {type(desired)}"
)
assert np.isscalar(actual) == np.isscalar(desired), _msg

return desired_xp


def _prepare_for_test(array: Array, xp: ModuleType) -> Array:
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
"""
Ensure that the array can be compared with xp.testing or np.testing.

This involves transferring it from GPU to CPU memory, densifying it, etc.
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
"""
if is_torch_namespace(xp):
return array.cpu() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
if is_cupy_namespace(xp):
return xp.asnumpy(array)
if is_pydata_sparse_namespace(xp):
return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]

if is_torch_namespace(xp):
array = to_device(array, "cpu")
if is_array_api_strict_namespace(xp):
# Note: we deliberately did not add a `.to_device` method in _typing.pyi
# even if it is required by the standard as many backends don't support it
return array.to_device(xp.Device("CPU_DEVICE")) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
# Note: nothing to do for CuPy, because it uses a bespoke test function
return array
cpu: Device = xp.Device("CPU_DEVICE")
array = to_device(array, cpu)
if is_jax_namespace(xp):
import jax

# Note: only needed if the transfer guard is enabled
cpu = cast(Device, jax.devices("cpu")[0])
array = to_device(array, cpu)

return np.asarray(array)

def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:

def xp_assert_equal(
actual: Array,
desired: Array,
*,
err_msg: str = "",
check_dtype: bool = True,
check_shape: bool = True,
check_scalar: bool = False,
) -> None:
"""
Array-API compatible version of `np.testing.assert_array_equal`.

Expand All @@ -95,34 +127,56 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
The expected array (typically hardcoded).
err_msg : str, optional
Error message to display on failure.
check_dtype, check_shape : bool, default: True
Whether to check agreement between actual and desired dtypes and shapes
check_scalar : bool, default: False
NumPy only: whether to check agreement between actual and desired types -
0d array vs scalar.

See Also
--------
xp_assert_close : Similar function for inexact equality checks.
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
xp = _check_ns_shape_dtype(actual, desired)
actual = _prepare_for_test(actual, xp)
desired = _prepare_for_test(desired, xp)
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
actual_np = as_numpy_array(actual, xp=xp)
desired_np = as_numpy_array(desired, xp=xp)
np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)

if is_cupy_namespace(xp):
xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
elif is_torch_namespace(xp):
# PyTorch recommends using `rtol=0, atol=0` like this
# to test for exact equality
xp.testing.assert_close(
actual,
desired,
rtol=0,
atol=0,
equal_nan=True,
check_dtype=False,
msg=err_msg or None,
)
else:
import numpy as np # pylint: disable=import-outside-toplevel

np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
def xp_assert_less(
x: Array,
y: Array,
*,
err_msg: str = "",
check_dtype: bool = True,
check_shape: bool = True,
check_scalar: bool = False,
) -> None:
"""
Array-API compatible version of `np.testing.assert_array_less`.

Parameters
----------
x, y : Array
The arrays to compare according to ``x < y`` (elementwise).
err_msg : str, optional
Error message to display on failure.
check_dtype, check_shape : bool, default: True
Whether to check agreement between actual and desired dtypes and shapes
check_scalar : bool, default: False
NumPy only: whether to check agreement between actual and desired types -
0d array vs scalar.

See Also
--------
xp_assert_close : Similar function for inexact equality checks.
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
x_np = as_numpy_array(x, xp=xp)
y_np = as_numpy_array(y, xp=xp)
np.testing.assert_array_less(x_np, y_np, err_msg=err_msg)


def xp_assert_close(
Expand All @@ -132,6 +186,9 @@ def xp_assert_close(
rtol: float | None = None,
atol: float = 0,
err_msg: str = "",
check_dtype: bool = True,
check_shape: bool = True,
check_scalar: bool = False,
) -> None:
"""
Array-API compatible version of `np.testing.assert_allclose`.
Expand All @@ -148,6 +205,11 @@ def xp_assert_close(
Absolute tolerance. Default: 0.
err_msg : str, optional
Error message to display on failure.
check_dtype, check_shape : bool, default: True
Whether to check agreement between actual and desired dtypes and shapes
check_scalar : bool, default: False
NumPy only: whether to check agreement between actual and desired types -
0d array vs scalar.

See Also
--------
Expand All @@ -159,40 +221,26 @@ def xp_assert_close(
-----
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
"""
xp = _check_ns_shape_dtype(actual, desired)

floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
if rtol is None and floating:
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
# roughly half way between sqrt(eps) and the default for
# `numpy.testing.assert_allclose`, 1e-7
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
elif rtol is None:
rtol = 1e-7

actual = _prepare_for_test(actual, xp)
desired = _prepare_for_test(desired, xp)

if is_cupy_namespace(xp):
xp.testing.assert_allclose(
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
)
elif is_torch_namespace(xp):
xp.testing.assert_close(
actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None
)
else:
import numpy as np # pylint: disable=import-outside-toplevel

# JAX/Dask arrays work directly with `np.testing`
assert isinstance(rtol, float)
np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
actual, # pyright: ignore[reportArgumentType]
desired, # pyright: ignore[reportArgumentType]
rtol=rtol,
atol=atol,
err_msg=err_msg,
)
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)

if rtol is None:
if xp.isdtype(actual.dtype, ("real floating", "complex floating")):
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
# roughly half way between sqrt(eps) and the default for
# `numpy.testing.assert_allclose`, 1e-7
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
else:
rtol = 1e-7
Comment on lines +226 to +233
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To save microseconds in a test?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it was an attempt to convince pyright. Regardless, I think it's slightly cleaner this way.


actual_np = as_numpy_array(actual, xp=xp)
desired_np = as_numpy_array(desired, xp=xp)
Comment on lines +235 to +236
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this fixing a pyright issue?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In data-apis#267 all testing becomes done exclusively by numpy.
So it makes sense to push the conversion into this function.
As a cascade effect, this makes type validation on behalf of pyright a lot easier.

np.testing.assert_allclose( # pyright: ignore[reportCallIssue]
actual_np,
desired_np,
rtol=rtol, # pyright: ignore[reportArgumentType]
atol=atol,
err_msg=err_msg,
)


def xfail(
Expand Down
Loading