From 169dd2b4995ae963b1d42599b27a663a37e16678 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 4 Jun 2025 11:04:30 +0100 Subject: [PATCH 1/3] MAINT: bump to sparse >=0.17 --- src/array_api_extra/_lib/_utils/_helpers.py | 28 +++++++++++---------- tests/test_funcs.py | 21 ++++++++-------- tests/test_helpers.py | 2 +- tests/test_testing.py | 2 +- 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 3e43fa9..f5f5cdd 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -322,26 +322,28 @@ def capabilities( dict Capabilities of the namespace. """ - if is_pydata_sparse_namespace(xp): - # No __array_namespace_info__(); no indexing by sparse arrays - return { - "boolean indexing": False, - "data-dependent shapes": True, - "max dimensions": None, - } out = xp.__array_namespace_info__().capabilities() - if is_jax_namespace(xp) and out["boolean indexing"]: - # FIXME https://github.com/jax-ml/jax/issues/27418 - # Fixed in jax >=0.6.0 - out = out.copy() - out["boolean indexing"] = False - if is_torch_namespace(xp): + if is_pydata_sparse_namespace(xp): + if out["boolean indexing"]: + # FIXME https://github.com/pydata/sparse/issues/876 + # boolean indexing is supported, but not when the index is a sparse array. + # boolean indexing by list or numpy array is not part of the Array API. + out = out.copy() + out["boolean indexing"] = False + elif is_jax_namespace(xp): + if out["boolean indexing"]: + # Backwards compatibility with jax <0.6.0 + # https://github.com/jax-ml/jax/issues/27418 + out = out.copy() + out["boolean indexing"] = False + elif is_torch_namespace(xp): # FIXME https://github.com/data-apis/array-api/issues/945 device = xp.get_default_device() if device is None else xp.device(device) if device.type == "meta": # type: ignore[union-attr] # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] out = out.copy() out["boolean indexing"] = False out["data-dependent shapes"] = False + return out diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 99da4a2..112cd01 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -416,7 +416,7 @@ def test_complex(self, xp: ModuleType): expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128) xp_assert_close(actual, expect) - @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="matmul with nan fillvalue") + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="matmul with nan fillvalue") def test_empty(self, xp: ModuleType): with warnings.catch_warnings(record=True): warnings.simplefilter("always", RuntimeWarning) @@ -451,7 +451,7 @@ def test_xp(self, xp: ModuleType): ) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange") +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False) class TestOneHot: @pytest.mark.parametrize("n_dim", range(4)) @pytest.mark.parametrize("num_classes", [1, 3, 10]) @@ -816,7 +816,7 @@ def test_bool_dtype(self, xp: ModuleType): isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True]) ) - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array") + @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape") def test_none_shape(self, xp: ModuleType): a = xp.asarray([1, 5, 0]) @@ -825,7 +825,7 @@ def test_none_shape(self, xp: ModuleType): a = a[a < 5] xp_assert_equal(isclose(a, b), xp.asarray([True, False])) - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array") + @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape") def test_none_shape_bool(self, xp: ModuleType): a = xp.asarray([True, True, False]) @@ -1141,10 +1141,10 @@ def test_xp(self, xp: ModuleType): class TestSinc: - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no linspace") def test_simple(self, xp: ModuleType): xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0)) - w = sinc(xp.linspace(-1, 1, 100)) + x = xp.asarray(np.linspace(-1, 1, 100)) + w = sinc(x) # check symmetry xp_assert_close(w, xp.flip(w, axis=0)) @@ -1153,11 +1153,12 @@ def test_dtype(self, xp: ModuleType, x: int | complex): with pytest.raises(ValueError, match="real floating data type"): _ = sinc(xp.asarray(x)) - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange") def test_3d(self, xp: ModuleType): - x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2)) - expected = xp.zeros((3, 3, 2), dtype=xp.float64) - expected = at(expected)[0, 0, 0].set(1.0) + x = np.arange(18, dtype=np.float64).reshape((3, 3, 2)) + expected = np.zeros_like(x) + expected[0, 0, 0] = 1 + x = xp.asarray(x) + expected = xp.asarray(expected) xp_assert_close(sinc(x), expected, atol=1e-15) def test_device(self, xp: ModuleType, device: Device): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 74a0ec9..4a545df 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -40,7 +40,7 @@ def override(func): lazy_xp_function(in1d, jax_jit=False) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse") +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_inverse") class TestIn1D: # cover both code paths diff --git a/tests/test_testing.py b/tests/test_testing.py index 43d6e8a..5a123db 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -140,7 +140,7 @@ def test_assert_less(self, xp: ModuleType): xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1])) @pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less]) - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array") + @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing") def test_none_shape(self, xp: ModuleType, func: Callable[..., None]): """On Dask and other lazy backends, test that a shape with NaN's or None's From 42b1c2a44743f15a151fb092a8eb059c02d4d988 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 4 Jun 2025 11:17:32 +0100 Subject: [PATCH 2/3] Bump CI to jax 0.6 --- src/array_api_extra/_lib/_utils/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index f5f5cdd..7851d32 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -331,7 +331,7 @@ def capabilities( out = out.copy() out["boolean indexing"] = False elif is_jax_namespace(xp): - if out["boolean indexing"]: + if out["boolean indexing"]: # pragma: no cover # Backwards compatibility with jax <0.6.0 # https://github.com/jax-ml/jax/issues/27418 out = out.copy() From 69c88d9e15913b0771dc5f2ecba67a4f0598f232 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 4 Jun 2025 14:58:57 +0100 Subject: [PATCH 3/3] link sparse issue --- tests/test_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 112cd01..88fb580 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -416,7 +416,7 @@ def test_complex(self, xp: ModuleType): expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128) xp_assert_close(actual, expect) - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="matmul with nan fillvalue") + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#877") def test_empty(self, xp: ModuleType): with warnings.catch_warnings(record=True): warnings.simplefilter("always", RuntimeWarning)