From fd7a79969c56f6a778744121b444de1c06a34a1f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sun, 18 May 2025 13:24:49 +0100 Subject: [PATCH 1/4] ENH: support PyTorch `device='meta'` --- src/array_api_extra/_lib/_funcs.py | 4 +-- src/array_api_extra/_lib/_testing.py | 6 ++++- src/array_api_extra/_lib/_utils/_helpers.py | 14 ++++++++-- tests/conftest.py | 8 +++--- tests/test_funcs.py | 6 ++--- tests/test_helpers.py | 30 +++++++++++++++++---- tests/test_lazy.py | 3 +++ tests/test_testing.py | 20 +++++++++++--- 8 files changed, 71 insertions(+), 20 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index bb39775f..645f7a1b 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -153,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01 ) -> Array: """Helper of `apply_where`. On Dask, this runs on a single chunk.""" - if not capabilities(xp)["boolean indexing"]: + if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]: # jax.jit does not support assignment by boolean mask return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value) @@ -716,7 +716,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: # 2. backend has unique_counts and it returns a None-sized array; # e.g. Dask, ndonnx # 3. backend does not have unique_counts; e.g. wrapped JAX - if capabilities(xp)["data-dependent shapes"]: + if capabilities(xp, device=_compat.device(x))["data-dependent shapes"]: # xp has unique_counts; O(n) complexity _, counts = xp.unique_counts(x) n = _compat.size(counts) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index c60cf466..e0535fbd 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -100,7 +100,11 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] if is_torch_namespace(xp): - array = to_device(array, "cpu") + if array.device.type == "meta": # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + # Can't materialize; generate dummy data instead + array = xp.zeros_like(array, device="cpu") + else: + array = to_device(array, "cpu") if is_array_api_strict_namespace(xp): cpu: Device = xp.Device("CPU_DEVICE") array = to_device(array, cpu) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index b856eb41..cf14873f 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -29,8 +29,9 @@ is_jax_namespace, is_numpy_array, is_pydata_sparse_namespace, + is_torch_namespace, ) -from ._typing import Array +from ._typing import Array, Device if TYPE_CHECKING: # pragma: no cover # TODO import from typing (requires Python >=3.12 and >=3.13) @@ -300,7 +301,7 @@ def meta_namespace( return array_namespace(*metas) -def capabilities(xp: ModuleType) -> dict[str, int]: +def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, int]: """ Return patched ``xp.__array_namespace_info__().capabilities()``. @@ -311,6 +312,8 @@ def capabilities(xp: ModuleType) -> dict[str, int]: ---------- xp : array_namespace The standard-compatible namespace. + device : Device, optional + The device to use. Returns ------- @@ -326,6 +329,13 @@ def capabilities(xp: ModuleType) -> dict[str, int]: # Fixed in jax >=0.6.0 out = out.copy() out["boolean indexing"] = False + if is_torch_namespace(xp): + # FIXME https://github.com/data-apis/array-api/issues/945 + device = xp.get_default_device() if device is None else xp.device(device) + if cast(Any, device).type == "meta": # type: ignore[explicit-any] + out = out.copy() + out["boolean indexing"] = False + out["data-dependent shapes"] = False return out diff --git a/tests/conftest.py b/tests/conftest.py index 372f9960..bb72139b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -211,7 +211,9 @@ def device( Where possible, return a device that is not the default one. """ if library == Backend.ARRAY_API_STRICT: - d = xp.Device("device1") - assert get_device(xp.empty(0)) != d - return d + return xp.Device("device1") + if library == Backend.TORCH: + return xp.device("meta") + if library == Backend.TORCH_GPU: + return xp.device("cpu") return get_device(xp.empty(0)) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 4cd1718c..7c138442 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -731,9 +731,6 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) res = isclose(a, b, equal_nan=equal_nan) assert get_device(res) == device - xp_assert_equal( - isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan]) - ) class TestKron: @@ -996,6 +993,9 @@ def test_all_python_scalars(self, assume_unique: bool): _ = setdiff1d(0, 0, assume_unique=assume_unique) @assume_unique + @pytest.mark.skip_xp_backend( + Backend.TORCH, reason="device='meta' does not support unknown shapes" + ) def test_device(self, xp: ModuleType, device: Device, assume_unique: bool): x1 = xp.asarray([3, 8, 20], device=device) x2 = xp.asarray([2, 3, 4], device=device) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index d2068f13..b1b6ddba 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -212,11 +212,31 @@ def test_xp(self, xp: ModuleType): assert meta_namespace(*args, xp=xp) in (xp, np_compat) -def test_capabilities(xp: ModuleType): - expect = {"boolean indexing", "data-dependent shapes"} - if xp.__array_api_version__ >= "2024.12": - expect.add("max dimensions") - assert capabilities(xp).keys() == expect +class TestCapabilities: + def test_basic(self, xp: ModuleType): + expect = {"boolean indexing", "data-dependent shapes"} + if xp.__array_api_version__ >= "2024.12": + expect.add("max dimensions") + assert capabilities(xp).keys() == expect + + def test_device(self, xp: ModuleType, library: Backend, device: Device): + expect_keys = {"boolean indexing", "data-dependent shapes"} + if xp.__array_api_version__ >= "2024.12": + expect_keys.add("max dimensions") + assert capabilities(xp, device=device).keys() == expect_keys + + if library.like(Backend.TORCH): + # The output of capabilities is device-specific. + + # Test that device=None gets the current default device. + expect = capabilities(xp, device=device) + with xp.device(device): + actual = capabilities(xp) + assert actual == expect + + # Test that we're accepting anything that is accepted by the + # device= parameter in other functions + actual = capabilities(xp, device=device.type) # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue] class Wrapper(Generic[T]): diff --git a/tests/test_lazy.py b/tests/test_lazy.py index aef73301..f97062f0 100644 --- a/tests/test_lazy.py +++ b/tests/test_lazy.py @@ -278,6 +278,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType): Backend.ARRAY_API_STRICT, reason="device->host copy" ), pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"), + pytest.mark.skip_xp_backend( + Backend.TORCH, reason="materialize 'meta' device" + ), pytest.mark.skip_xp_backend( Backend.TORCH_GPU, reason="device->host copy" ), diff --git a/tests/test_testing.py b/tests/test_testing.py index 51a7775e..819ee7ac 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -39,10 +39,22 @@ ) -def test_as_numpy_array(xp: ModuleType, device: Device): - x = xp.asarray([1, 2, 3], device=device) - y = as_numpy_array(x, xp=xp) - assert isinstance(y, np.ndarray) +class TestAsNumPyArray: + def test_basic(self, xp: ModuleType): + x = xp.asarray([1, 2, 3]) + y = as_numpy_array(x, xp=xp) + xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + + def test_device(self, xp: ModuleType, library: Backend, device: Device): + x = xp.asarray([1, 2, 3], device=device) + actual = as_numpy_array(x, xp=xp) + if library is Backend.TORCH: + assert device.type == "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + expect = np.asarray([0, 0, 0]) + else: + expect = np.asarray([1, 2, 3]) + + xp_assert_equal(actual, expect) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False) From 90c3aec415251364261e1bcb66b330ecde8bd3de Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sun, 18 May 2025 23:25:07 +0100 Subject: [PATCH 2/4] don't cast device to Any --- src/array_api_extra/_lib/_utils/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index cf14873f..f4997dd3 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -332,7 +332,7 @@ def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, i if is_torch_namespace(xp): # FIXME https://github.com/data-apis/array-api/issues/945 device = xp.get_default_device() if device is None else xp.device(device) - if cast(Any, device).type == "meta": # type: ignore[explicit-any] + if device.type == "meta": # type: ignore[union-attr] # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] out = out.copy() out["boolean indexing"] = False out["data-dependent shapes"] = False From ae0bf1f69ddbb87c0711ba1d2258b9749f6dfd74 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 19 May 2025 00:01:23 +0100 Subject: [PATCH 3/4] Fix xp_array_less --- src/array_api_extra/_lib/_testing.py | 52 ++++++++++++++++++++-------- tests/test_testing.py | 29 ++++++++++------ 2 files changed, 57 insertions(+), 24 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index e0535fbd..38cd612b 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -22,6 +22,7 @@ is_jax_namespace, is_numpy_namespace, is_pydata_sparse_namespace, + is_torch_array, is_torch_namespace, to_device, ) @@ -62,18 +63,28 @@ def _check_ns_shape_dtype( msg = f"namespaces do not match: {actual_xp} != f{desired_xp}" assert actual_xp == desired_xp, msg - if check_shape: - actual_shape = actual.shape - desired_shape = desired.shape - if is_dask_namespace(desired_xp): - # Dask uses nan instead of None for unknown shapes - if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)): - actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)): - desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + # Dask uses nan instead of None for unknown shapes + actual_shape = cast(tuple[float, ...], actual.shape) + desired_shape = cast(tuple[float, ...], desired.shape) + assert None not in actual_shape # Requires explicit support + assert None not in desired_shape + if is_dask_namespace(desired_xp): + if any(math.isnan(i) for i in actual_shape): + actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + if any(math.isnan(i) for i in desired_shape): + desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + if check_shape: msg = f"shapes do not match: {actual_shape} != f{desired_shape}" assert actual_shape == desired_shape, msg + else: + # Ignore shape, but check flattened size. This is normally done by + # np.testing.assert_array_equal etc even when strict=False, but not for + # non-materializable arrays. + actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType] + desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType] + msg = f"sizes do not match: {actual_size} != f{desired_size}" + assert actual_size == desired_size, msg if check_dtype: msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}" @@ -90,6 +101,17 @@ def _check_ns_shape_dtype( return desired_xp +def _is_materializable(x: Array) -> bool: + """ + Check if the array is materializable, e.g. `as_numpy_array` can be called on it + and one can assume that `__dlpack__` will succeed (if implemented, and given a + compatible device). + """ + # Important: here we assume that we're not tracing - + # e.g. we're not inside `jax.jit`` nor `cupy.cuda.Stream.begin_capture`. + return not is_torch_array(x) or x.device.type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + + def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any] """ Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards. @@ -100,11 +122,7 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] if is_torch_namespace(xp): - if array.device.type == "meta": # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - # Can't materialize; generate dummy data instead - array = xp.zeros_like(array, device="cpu") - else: - array = to_device(array, "cpu") + array = to_device(array, "cpu") if is_array_api_strict_namespace(xp): cpu: Device = xp.Device("CPU_DEVICE") array = to_device(array, cpu) @@ -150,6 +168,8 @@ def xp_assert_equal( numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + if not _is_materializable(actual): + return actual_np = as_numpy_array(actual, xp=xp) desired_np = as_numpy_array(desired, xp=xp) np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) @@ -185,6 +205,8 @@ def xp_assert_less( numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) + if not _is_materializable(x): + return x_np = as_numpy_array(x, xp=xp) y_np = as_numpy_array(y, xp=xp) np.testing.assert_array_less(x_np, y_np, err_msg=err_msg) @@ -233,6 +255,8 @@ def xp_assert_close( The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`. """ xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + if not _is_materializable(actual): + return if rtol is None: if xp.isdtype(actual.dtype, ("real floating", "complex floating")): diff --git a/tests/test_testing.py b/tests/test_testing.py index bda88911..b195a65b 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -30,16 +30,11 @@ def test_basic(self, xp: ModuleType): y = as_numpy_array(x, xp=xp) xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - def test_device(self, xp: ModuleType, library: Backend, device: Device): + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device") + def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3], device=device) - actual = as_numpy_array(x, xp=xp) - if library is Backend.TORCH: - assert device.type == "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - expect = np.asarray([0, 0, 0]) - else: - expect = np.asarray([1, 2, 3]) - - xp_assert_equal(actual, expect) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + y = as_numpy_array(x, xp=xp) + xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] class TestAssertEqualCloseLess: @@ -92,7 +87,7 @@ def test_check_shape(self, xp: ModuleType, func: Callable[..., None]): func(a, b, check_shape=False) with pytest.raises(AssertionError, match="Mismatched elements"): func(a, c, check_shape=False) - with pytest.raises(AssertionError, match=r"shapes \(1,\), \(2,\) mismatch"): + with pytest.raises(AssertionError, match="sizes do not match"): func(a, d, check_shape=False) @pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less]) @@ -181,6 +176,20 @@ def test_none_shape(self, xp: ModuleType, func: Callable[..., None]): with pytest.raises(AssertionError, match="Mismatched elements"): func(xp.asarray([4]), a) + @pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less]) + def test_device(self, xp: ModuleType, device: Device, func: Callable[..., None]): + a = xp.asarray([1] if func is xp_assert_less else [2], device=device) + b = xp.asarray([2], device=device) + c = xp.asarray([2, 2], device=device) + + func(a, b) + with pytest.raises(AssertionError, match="shapes do not match"): + func(a, c) + # This is normally performed by np.testing.assert_array_equal etc. + # but in case of torch device='meta' we have to do it manually + with pytest.raises(AssertionError, match="sizes do not match"): + func(a, c, check_shape=False) + def good_lazy(x: Array) -> Array: """A function that behaves well in Dask and jax.jit""" From baf688066ea6ddbda54f33c25319a485b7cc5226 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 19 May 2025 00:03:34 +0100 Subject: [PATCH 4/4] nit --- src/array_api_extra/_lib/_testing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 38cd612b..16a9d102 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -103,9 +103,7 @@ def _check_ns_shape_dtype( def _is_materializable(x: Array) -> bool: """ - Check if the array is materializable, e.g. `as_numpy_array` can be called on it - and one can assume that `__dlpack__` will succeed (if implemented, and given a - compatible device). + Return True if you can call `as_numpy_array(x)`; False otherwise. """ # Important: here we assume that we're not tracing - # e.g. we're not inside `jax.jit`` nor `cupy.cuda.Stream.begin_capture`.