Skip to content

Commit a91beff

Browse files
committed
TST: fix up sparse tests
1 parent 74789e2 commit a91beff

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,11 @@ def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, i
322322
"""
323323
if is_pydata_sparse_namespace(xp):
324324
# No __array_namespace_info__(); no indexing by sparse arrays
325-
return {"boolean indexing": False, "data-dependent shapes": True}
325+
return {
326+
"boolean indexing": False,
327+
"data-dependent shapes": True,
328+
"max dimensions": None,
329+
}
326330
out = xp.__array_namespace_info__().capabilities()
327331
if is_jax_namespace(xp) and out["boolean indexing"]:
328332
# FIXME https://github.com/jax-ml/jax/issues/27418

tests/test_funcs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def test_complex(self, xp: ModuleType):
416416
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
417417
xp_assert_close(actual, expect)
418418

419+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="matmul with nan fillvalue")
419420
def test_empty(self, xp: ModuleType):
420421
with warnings.catch_warnings(record=True):
421422
warnings.simplefilter("always", RuntimeWarning)
@@ -611,7 +612,6 @@ def test_xp(self, xp: ModuleType):
611612
xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]]))
612613

613614

614-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no __array_namespace_info__")
615615
class TestDefaultDType:
616616
def test_basic(self, xp: ModuleType):
617617
assert default_dtype(xp) == xp.empty(0).dtype
@@ -696,6 +696,9 @@ def test_xp(self, xp: ModuleType):
696696
@pytest.mark.filterwarnings( # array_api_strictest
697697
"ignore:invalid value encountered:RuntimeWarning:array_api_strict"
698698
)
699+
@pytest.mark.filterwarnings( # sparse
700+
"ignore:invalid value encountered:RuntimeWarning:sparse"
701+
)
699702
class TestIsClose:
700703
@pytest.mark.parametrize("swap", [False, True])
701704
@pytest.mark.parametrize(
@@ -813,6 +816,7 @@ def test_bool_dtype(self, xp: ModuleType):
813816
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
814817
)
815818

819+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
816820
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
817821
def test_none_shape(self, xp: ModuleType):
818822
a = xp.asarray([1, 5, 0])
@@ -821,6 +825,7 @@ def test_none_shape(self, xp: ModuleType):
821825
a = a[a < 5]
822826
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
823827

828+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
824829
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
825830
def test_none_shape_bool(self, xp: ModuleType):
826831
a = xp.asarray([True, True, False])
@@ -1136,6 +1141,7 @@ def test_xp(self, xp: ModuleType):
11361141

11371142

11381143
class TestSinc:
1144+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no linspace")
11391145
def test_simple(self, xp: ModuleType):
11401146
xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0))
11411147
w = sinc(xp.linspace(-1, 1, 100))
@@ -1147,6 +1153,7 @@ def test_dtype(self, xp: ModuleType, x: int | complex):
11471153
with pytest.raises(ValueError, match="real floating data type"):
11481154
_ = sinc(xp.asarray(x))
11491155

1156+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange")
11501157
def test_3d(self, xp: ModuleType):
11511158
x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2))
11521159
expected = xp.zeros((3, 3, 2), dtype=xp.float64)

0 commit comments

Comments
 (0)