diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 3e43fa9..7851d32 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -322,26 +322,28 @@ def capabilities( dict Capabilities of the namespace. """ - if is_pydata_sparse_namespace(xp): - # No __array_namespace_info__(); no indexing by sparse arrays - return { - "boolean indexing": False, - "data-dependent shapes": True, - "max dimensions": None, - } out = xp.__array_namespace_info__().capabilities() - if is_jax_namespace(xp) and out["boolean indexing"]: - # FIXME https://github.com/jax-ml/jax/issues/27418 - # Fixed in jax >=0.6.0 - out = out.copy() - out["boolean indexing"] = False - if is_torch_namespace(xp): + if is_pydata_sparse_namespace(xp): + if out["boolean indexing"]: + # FIXME https://github.com/pydata/sparse/issues/876 + # boolean indexing is supported, but not when the index is a sparse array. + # boolean indexing by list or numpy array is not part of the Array API. + out = out.copy() + out["boolean indexing"] = False + elif is_jax_namespace(xp): + if out["boolean indexing"]: # pragma: no cover + # Backwards compatibility with jax <0.6.0 + # https://github.com/jax-ml/jax/issues/27418 + out = out.copy() + out["boolean indexing"] = False + elif is_torch_namespace(xp): # FIXME https://github.com/data-apis/array-api/issues/945 device = xp.get_default_device() if device is None else xp.device(device) if device.type == "meta": # type: ignore[union-attr] # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] out = out.copy() out["boolean indexing"] = False out["data-dependent shapes"] = False + return out diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 99da4a2..88fb580 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -416,7 +416,7 @@ def test_complex(self, xp: ModuleType): expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128) xp_assert_close(actual, expect) - @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="matmul with nan fillvalue") + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#877") def test_empty(self, xp: ModuleType): with warnings.catch_warnings(record=True): warnings.simplefilter("always", RuntimeWarning) @@ -451,7 +451,7 @@ def test_xp(self, xp: ModuleType): ) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange") +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False) class TestOneHot: @pytest.mark.parametrize("n_dim", range(4)) @pytest.mark.parametrize("num_classes", [1, 3, 10]) @@ -816,7 +816,7 @@ def test_bool_dtype(self, xp: ModuleType): isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True]) ) - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array") + @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape") def test_none_shape(self, xp: ModuleType): a = xp.asarray([1, 5, 0]) @@ -825,7 +825,7 @@ def test_none_shape(self, xp: ModuleType): a = a[a < 5] xp_assert_equal(isclose(a, b), xp.asarray([True, False])) - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array") + @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape") def test_none_shape_bool(self, xp: ModuleType): a = xp.asarray([True, True, False]) @@ -1141,10 +1141,10 @@ def test_xp(self, xp: ModuleType): class TestSinc: - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no linspace") def test_simple(self, xp: ModuleType): xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0)) - w = sinc(xp.linspace(-1, 1, 100)) + x = xp.asarray(np.linspace(-1, 1, 100)) + w = sinc(x) # check symmetry xp_assert_close(w, xp.flip(w, axis=0)) @@ -1153,11 +1153,12 @@ def test_dtype(self, xp: ModuleType, x: int | complex): with pytest.raises(ValueError, match="real floating data type"): _ = sinc(xp.asarray(x)) - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange") def test_3d(self, xp: ModuleType): - x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2)) - expected = xp.zeros((3, 3, 2), dtype=xp.float64) - expected = at(expected)[0, 0, 0].set(1.0) + x = np.arange(18, dtype=np.float64).reshape((3, 3, 2)) + expected = np.zeros_like(x) + expected[0, 0, 0] = 1 + x = xp.asarray(x) + expected = xp.asarray(expected) xp_assert_close(sinc(x), expected, atol=1e-15) def test_device(self, xp: ModuleType, device: Device): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 74a0ec9..4a545df 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -40,7 +40,7 @@ def override(func): lazy_xp_function(in1d, jax_jit=False) -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse") +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_inverse") class TestIn1D: # cover both code paths diff --git a/tests/test_testing.py b/tests/test_testing.py index 43d6e8a..5a123db 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -140,7 +140,7 @@ def test_assert_less(self, xp: ModuleType): xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1])) @pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less]) - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array") + @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing") def test_none_shape(self, xp: ModuleType, func: Callable[..., None]): """On Dask and other lazy backends, test that a shape with NaN's or None's