Skip to content

TST: rework tests for xp_assert_equal #301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 132 additions & 139 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Callable
from contextlib import nullcontext
from types import ModuleType
from typing import cast

Expand All @@ -21,160 +20,154 @@
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function

# mypy: disable-error-code=decorated-any
# mypy: disable-error-code="decorated-any, explicit-any"
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false

param_assert_equal_close = pytest.mark.parametrize(
"func",
[
xp_assert_equal,
xp_assert_less,
pytest.param(
xp_assert_close,
marks=pytest.mark.xfail_xp_backend(
Backend.SPARSE, reason="no isdtype", strict=False
),
),
],
)


def test_as_numpy_array(xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3], device=device)
y = as_numpy_array(x, xp=xp)
assert isinstance(y, np.ndarray)


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
func(xp.asarray(0), xp.asarray(0))
func(xp.asarray([1, 2]), xp.asarray([1, 2]))

with pytest.raises(AssertionError, match="shapes do not match"):
func(xp.asarray([0]), xp.asarray([[0]]))

with pytest.raises(AssertionError, match="dtypes do not match"):
func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64))

with pytest.raises(AssertionError):
func(xp.asarray([1, 2]), xp.asarray([1, 3]))

with pytest.raises(AssertionError, match="hello"):
func(xp.asarray([1, 2]), xp.asarray([1, 3]), err_msg="hello")


@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="test other ns vs. numpy")
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
def test_assert_close_equal_less_namespace(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
with pytest.raises(AssertionError, match="namespaces do not match"):
func(xp.asarray(0), np.asarray(0))
with pytest.raises(TypeError, match="Unrecognized array input"):
func(xp.asarray(0), 0)
with pytest.raises(TypeError, match="list is not a supported array type"):
func(xp.asarray([0]), [0])


@param_assert_equal_close
@pytest.mark.parametrize("check_shape", [False, True])
def test_assert_close_equal_less_shape( # type: ignore[explicit-any]
xp: ModuleType,
func: Callable[..., None],
check_shape: bool,
):
context = (
pytest.raises(AssertionError, match="shapes do not match")
if check_shape
else nullcontext()
)
with context:
# note: NaNs are handled by all 3 checks
func(xp.asarray([xp.nan, xp.nan]), xp.asarray(xp.nan), check_shape=check_shape)


@param_assert_equal_close
@pytest.mark.parametrize("check_dtype", [False, True])
def test_assert_close_equal_less_dtype( # type: ignore[explicit-any]
xp: ModuleType,
func: Callable[..., None],
check_dtype: bool,
):
context = (
pytest.raises(AssertionError, match="dtypes do not match")
if check_dtype
else nullcontext()
)
with context:
func(
xp.asarray(xp.nan, dtype=xp.float32),
xp.asarray(xp.nan, dtype=xp.float64),
check_dtype=check_dtype,
)


@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
@pytest.mark.parametrize("check_scalar", [False, True])
def test_assert_close_equal_less_scalar( # type: ignore[explicit-any]
xp: ModuleType,
func: Callable[..., None],
check_scalar: bool,
):
context = (
pytest.raises(AssertionError, match="array-ness does not match")
if check_scalar
else nullcontext()
class TestAssertEqualCloseLess:
pr_assert_close = pytest.param( # pyright: ignore[reportUnannotatedClassAttribute]
xp_assert_close,
marks=pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype"),
)
with context:
func(np.asarray(xp.nan), np.asarray(xp.nan)[()], check_scalar=check_scalar)

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close])
def test_assert_equal_close_basic(self, xp: ModuleType, func: Callable[..., None]):
func(xp.asarray(0), xp.asarray(0))
func(xp.asarray([1, 2]), xp.asarray([1, 2]))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
def test_assert_close_tolerance(xp: ModuleType):
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03)
with pytest.raises(AssertionError):
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.01)
with pytest.raises(AssertionError, match="Mismatched elements"):
func(xp.asarray([1, 2]), xp.asarray([2, 1]))

xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=3)
with pytest.raises(AssertionError):
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)
with pytest.raises(AssertionError, match="hello"):
func(xp.asarray([1, 2]), xp.asarray([2, 1]), err_msg="hello")

@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
def test_shape_dtype(self, xp: ModuleType, func: Callable[..., None]):
with pytest.raises(AssertionError, match="shapes do not match"):
func(xp.asarray([0]), xp.asarray([[0]]))

def test_assert_less_basic(xp: ModuleType):
xp_assert_less(xp.asarray(-1), xp.asarray(0))
xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
with pytest.raises(AssertionError):
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
with pytest.raises(AssertionError, match="hello"):
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]), err_msg="hello")
with pytest.raises(AssertionError, match="dtypes do not match"):
func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64))


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
"""On Dask and other lazy backends, test that a shape with NaN's or None's
can be compared to a real shape.
"""
a = xp.asarray([1, 2])
a = a[a > 1]

func(a, xp.asarray([2]))
with pytest.raises(AssertionError):
func(a, xp.asarray([2, 3]))
with pytest.raises(AssertionError):
func(a, xp.asarray(2))
with pytest.raises(AssertionError):
func(a, xp.asarray([3]))

# Swap actual and desired
func(xp.asarray([2]), a)
with pytest.raises(AssertionError):
func(xp.asarray([2, 3]), a)
with pytest.raises(AssertionError):
func(xp.asarray(2), a)
with pytest.raises(AssertionError):
func(xp.asarray([3]), a)
@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
@pytest.mark.skip_xp_backend(
Backend.NUMPY_READONLY, reason="test other ns vs. numpy"
)
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
def test_namespace(self, xp: ModuleType, func: Callable[..., None]):
with pytest.raises(AssertionError, match="namespaces do not match"):
func(xp.asarray(0), np.asarray(0))
with pytest.raises(TypeError, match="Unrecognized array input"):
func(xp.asarray(0), 0)
with pytest.raises(TypeError, match="list is not a supported array type"):
func(xp.asarray([0]), [0])

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
a = xp.asarray([1] if func is xp_assert_less else [2])
b = xp.asarray(2)
c = xp.asarray(0)
d = xp.asarray([2, 2])

with pytest.raises(AssertionError, match="shapes do not match"):
func(a, b)
func(a, b, check_shape=False)
with pytest.raises(AssertionError, match="Mismatched elements"):
func(a, c, check_shape=False)
with pytest.raises(AssertionError, match=r"shapes \(1,\), \(2,\) mismatch"):
func(a, d, check_shape=False)

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
def test_check_dtype(self, xp: ModuleType, func: Callable[..., None]):
a = xp.asarray(1 if func is xp_assert_less else 2)
b = xp.asarray(2, dtype=xp.int16)
c = xp.asarray(0, dtype=xp.int16)

with pytest.raises(AssertionError, match="dtypes do not match"):
func(a, b)
func(a, b, check_dtype=False)
with pytest.raises(AssertionError, match="Mismatched elements"):
func(a, c, check_dtype=False)

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
@pytest.mark.xfail_xp_backend(
Backend.SPARSE, reason="sparse [()] returns np.generic"
)
def test_check_scalar(
self, xp: ModuleType, library: Backend, func: Callable[..., None]
):
a = xp.asarray(1 if func is xp_assert_less else 2)
b = xp.asarray(2)[()] # Note: only makes a difference on NumPy
c = xp.asarray(0)

func(a, b)
if library.like(Backend.NUMPY):
with pytest.raises(AssertionError, match="array-ness does not match"):
func(a, b, check_scalar=True)
else:
func(a, b, check_scalar=True)
with pytest.raises(AssertionError, match="Mismatched elements"):
func(a, c, check_scalar=True)

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.parametrize("dtype", ["int64", "float64"])
def test_assert_close_tolerance(self, dtype: str, xp: ModuleType):
a = xp.asarray([100], dtype=getattr(xp, dtype))
b = xp.asarray([102], dtype=getattr(xp, dtype))

with pytest.raises(AssertionError, match="Mismatched elements"):
xp_assert_close(a, b)

xp_assert_close(a, b, rtol=0.03)
with pytest.raises(AssertionError, match="Mismatched elements"):
xp_assert_close(a, b, rtol=0.01)

xp_assert_close(a, b, atol=3)
with pytest.raises(AssertionError, match="Mismatched elements"):
xp_assert_close(a, b, atol=1)

def test_assert_less(self, xp: ModuleType):
xp_assert_less(xp.asarray(-1), xp.asarray(0))
xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
with pytest.raises(AssertionError, match="Mismatched elements"):
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
"""On Dask and other lazy backends, test that a shape with NaN's or None's
can be compared to a real shape.
"""
# actual has shape=(None, )
a = xp.asarray([1] if func is xp_assert_less else [2])
a = a[a > 0]

func(a, xp.asarray([2]))
with pytest.raises(AssertionError, match="shapes do not match"):
func(a, xp.asarray(2))
with pytest.raises(AssertionError, match="shapes do not match"):
func(a, xp.asarray([2, 3]))
with pytest.raises(AssertionError, match="Mismatched elements"):
func(a, xp.asarray([0]))

# desired has shape=(None, )
a = xp.asarray([3] if func is xp_assert_less else [2])
a = a[a > 0]

func(xp.asarray([2]), a)
with pytest.raises(AssertionError, match="shapes do not match"):
func(xp.asarray(2), a)
with pytest.raises(AssertionError, match="shapes do not match"):
func(xp.asarray([2, 3]), a)
with pytest.raises(AssertionError, match="Mismatched elements"):
func(xp.asarray([4]), a)


def good_lazy(x: Array) -> Array:
Expand Down