From 428ee84917f1fc1051becb7b25625142698f6197 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 5 Dec 2024 22:02:25 +0000 Subject: [PATCH 1/3] Support changed asarray behaviour in dask 2023.12.0 --- array_api_compat/dask/array/_aliases.py | 29 ++++++++-------- tests/test_common.py | 44 +++++++++++++++++-------- 2 files changed, 46 insertions(+), 27 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index ee2d88c0..b8a72035 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -128,23 +128,26 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. + + .. note:: + copy=True means that if you update the output array the input will never + be affected; however the output array may internally hold references to the + input array, preventing deallocation. This kind of implementation detail should + be left at dask's discretion. """ if copy is False: # copy=False is not yet implemented in dask - raise NotImplementedError("copy=False is not yet implemented") - elif copy is True: - if isinstance(obj, da.Array) and dtype is None: - return obj.copy() - # Go through numpy, since dask copy is no-op by default - obj = np.array(obj, dtype=dtype, copy=True) - return da.array(obj, dtype=dtype) - else: - if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype: - obj = np.asarray(obj, dtype=dtype) - return da.from_array(obj) - return obj + raise NotImplementedError("copy=False can't be implemented in dask") + + if ( + copy is True + and isinstance(obj, da.Array) + and (dtype is None or dtype == obj.dtype) + ): + return obj.copy() + + return da.asarray(obj, dtype=dtype) - return da.asarray(obj, dtype=dtype, **kwargs) from dask.array import ( # Element wise aliases diff --git a/tests/test_common.py b/tests/test_common.py index e1cfa9eb..7e5e077d 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -112,6 +112,7 @@ def test_asarray_cross_library(source_library, target_library, request): assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" + @pytest.mark.parametrize("library", wrapped_libraries) def test_asarray_copy(library): # Note, we have this test here because the test suite currently doesn't @@ -130,6 +131,7 @@ def test_asarray_copy(library): else: supports_copy_false = True + # Tests for copy=True a = asarray([1]) b = asarray(a, copy=True) assert is_lib_func(b) @@ -137,6 +139,14 @@ def test_asarray_copy(library): assert all(b[0] == 1) assert all(a[0] == 0) + a = asarray([1]) + b = asarray(a, copy=True, dtype=a.dtype) + assert is_lib_func(b) + a[0] = 0 + assert all(b[0] == 1) + assert all(a[0] == 0) + + # Tests for copy=False a = asarray([1]) if supports_copy_false: b = asarray(a, copy=False) @@ -144,20 +154,26 @@ def test_asarray_copy(library): a[0] = 0 assert all(b[0] == 0) else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + with pytest.raises(NotImplementedError): + asarray(a, copy=False) a = asarray([1]) if supports_copy_false: - pytest.raises(ValueError, lambda: asarray(a, copy=False, - dtype=xp.float64)) + with pytest.raises(ValueError): + asarray(a, copy=False, dtype=xp.float64) else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64)) + with pytest.raises(NotImplementedError): + asarray(a, copy=False, dtype=xp.float64) + # Tests for copy=None + # Do not test whether the buffer is shared or not after copy=None. + # A library should have the freedom to alter its behaviour + # without treating it as a breaking change. a = asarray([1]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 0) + assert all((b[0] == 1.0) | (b[0] == 0.0)) a = asarray([1.0], dtype=xp.float32) assert a.dtype == xp.float32 @@ -165,6 +181,7 @@ def test_asarray_copy(library): assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 + # dtype change must always trigger a copy assert all(b[0] == 1.0) a = asarray([1.0], dtype=xp.float64) @@ -173,16 +190,18 @@ def test_asarray_copy(library): assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 0.0) + assert all((b[0] == 1.0) | (b[0] == 0.0)) # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: asarray(obj, copy=True) # No error asarray(obj, copy=None) # No error if supports_copy_false: - pytest.raises(ValueError, lambda: asarray(obj, copy=False)) + with pytest.raises(ValueError): + asarray(obj, copy=False) else: - pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) + with pytest.raises(NotImplementedError): + asarray(obj, copy=False) # Use the standard library array to test the buffer protocol a = array.array('f', [1.0]) @@ -198,14 +217,11 @@ def test_asarray_copy(library): a[0] = 0.0 assert all(b[0] == 0.0) else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + with pytest.raises(NotImplementedError): + asarray(a, copy=False) a = array.array('f', [1.0]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0.0 - if library == 'cupy': - # A copy is required for libraries where the default device is not CPU - assert all(b[0] == 1.0) - else: - assert all(b[0] == 0.0) + assert all((b[0] == 1.0) | (b[0] == 0.0)) From 9b1d449de5d843e834beb3e9ff2a952e0737c3d8 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 5 Dec 2024 22:31:17 +0000 Subject: [PATCH 2/3] xfail --- tests/test_common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_common.py b/tests/test_common.py index 7e5e077d..199dbc3b 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -91,7 +91,10 @@ def test_to_device_host(library): @pytest.mark.parametrize("target_library", is_array_functions.keys()) @pytest.mark.parametrize("source_library", is_array_functions.keys()) def test_asarray_cross_library(source_library, target_library, request): - if source_library == "dask.array" and target_library == "torch": + if ( + (source_library == "dask.array" and target_library == "torch") + or (source_library == "torch" and target_library == "dask.array") + ): # Allow rest of test to execute instead of immediately xfailing # xref https://github.com/pandas-dev/pandas/issues/38902 From b570a8424b1fec17a1185dc5c0e9fa130abcde1d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 5 Dec 2024 22:41:29 +0000 Subject: [PATCH 3/3] Fix backwards compat --- array_api_compat/dask/array/_aliases.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index b8a72035..4371f769 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -146,7 +146,13 @@ def asarray( ): return obj.copy() - return da.asarray(obj, dtype=dtype) + obj = da.asarray(obj, dtype=dtype) + + # Backport https://github.com/dask/dask/pull/11586 + if dtype not in (None, obj.dtype): + obj = obj.astype(dtype) + + return obj from dask.array import (