Skip to content

Commit f6b5ea2

Browse files
committed
Add dask.array specific implementation of asarray()
This also fixes several bugs in the implementation, and updates some tests.
1 parent 11d27dd commit f6b5ea2

File tree

3 files changed

+51
-96
lines changed

3 files changed

+51
-96
lines changed

array_api_compat/common/_aliases.py

Lines changed: 3 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
9-
import numpy as np
109
from typing import Optional, Sequence, Tuple, Union
1110
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
1211

1312
from typing import NamedTuple
14-
from types import ModuleType
1513
import inspect
1614

17-
from ._helpers import _check_device, is_numpy_array, array_namespace
15+
from ._helpers import _check_device
1816

1917
# These functions are modified from the NumPy versions.
2018

19+
# Creation functions add the device keyword (which does nothing for NumPy)
20+
2121
def arange(
2222
start: Union[int, float],
2323
/,
@@ -268,92 +268,6 @@ def var(
268268
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
269269
return xp.transpose(x, axes)
270270

271-
# Creation functions add the device keyword (which does nothing for NumPy)
272-
273-
# asarray also adds the copy keyword
274-
def _asarray(
275-
obj: Union[
276-
ndarray,
277-
bool,
278-
int,
279-
float,
280-
NestedSequence[bool | int | float],
281-
SupportsBufferProtocol,
282-
],
283-
/,
284-
*,
285-
dtype: Optional[Dtype] = None,
286-
device: Optional[Device] = None,
287-
copy: "Optional[Union[bool, np._CopyMode]]" = None,
288-
namespace = None,
289-
**kwargs,
290-
) -> ndarray:
291-
"""
292-
Array API compatibility wrapper for asarray().
293-
294-
See the corresponding documentation in the array library and/or the array API
295-
specification for more details.
296-
297-
'namespace' may be an array module namespace. This is needed to support
298-
conversion of sequences of Python scalars.
299-
"""
300-
if namespace is None:
301-
try:
302-
xp = array_namespace(obj, _use_compat=False)
303-
except ValueError:
304-
# TODO: What about lists of arrays?
305-
raise ValueError("A namespace must be specified for asarray() with non-array input")
306-
elif isinstance(namespace, ModuleType):
307-
xp = namespace
308-
elif namespace == 'numpy':
309-
import numpy as xp
310-
elif namespace == 'cupy':
311-
import cupy as xp
312-
elif namespace == 'dask.array':
313-
import dask.array as xp
314-
else:
315-
raise ValueError("Unrecognized namespace argument to asarray()")
316-
317-
_check_device(xp, device)
318-
if is_numpy_array(obj):
319-
import numpy as np
320-
if hasattr(np, '_CopyMode'):
321-
# Not present in older NumPys
322-
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
323-
COPY_TRUE = (True, np._CopyMode.ALWAYS)
324-
else:
325-
COPY_FALSE = (False,)
326-
COPY_TRUE = (True,)
327-
else:
328-
COPY_FALSE = (False,)
329-
COPY_TRUE = (True,)
330-
if copy in COPY_FALSE and namespace != "dask.array":
331-
# copy=False is not yet implemented in xp.asarray
332-
raise NotImplementedError("copy=False is not yet implemented")
333-
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
334-
if dtype is not None and obj.dtype != dtype:
335-
copy = True
336-
if copy in COPY_TRUE:
337-
return xp.array(obj, copy=True, dtype=dtype)
338-
return obj
339-
elif namespace == "dask.array":
340-
if copy in COPY_TRUE:
341-
if dtype is None:
342-
return obj.copy()
343-
# Go through numpy, since dask copy is no-op by default
344-
import numpy as np
345-
obj = np.array(obj, dtype=dtype, copy=True)
346-
return xp.array(obj, dtype=dtype)
347-
else:
348-
import dask.array as da
349-
import numpy as np
350-
if not isinstance(obj, da.Array):
351-
obj = np.asarray(obj, dtype=dtype)
352-
return da.from_array(obj)
353-
return obj
354-
355-
return xp.asarray(obj, dtype=dtype, **kwargs)
356-
357271
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
358272
def reshape(x: ndarray,
359273
/,

array_api_compat/dask/array/_aliases.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
if TYPE_CHECKING:
3838
from typing import Optional, Union
3939

40-
from ...common._typing import Device, Dtype, Array
40+
from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
4141

4242
import dask.array as da
4343

@@ -76,10 +76,6 @@ def _dask_arange(
7676
arange = get_xp(da)(_dask_arange)
7777
eye = get_xp(da)(_aliases.eye)
7878

79-
from functools import partial
80-
asarray = partial(_aliases._asarray, namespace='dask.array')
81-
asarray.__doc__ = _aliases._asarray.__doc__
82-
8379
linspace = get_xp(da)(_aliases.linspace)
8480
eye = get_xp(da)(_aliases.eye)
8581
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
@@ -113,6 +109,47 @@ def _dask_arange(
113109
matmul = get_xp(np)(_aliases.matmul)
114110
tensordot = get_xp(np)(_aliases.tensordot)
115111

112+
113+
# asarray also adds the copy keyword, which is not present in numpy 1.0.
114+
def asarray(
115+
obj: Union[
116+
Array,
117+
bool,
118+
int,
119+
float,
120+
NestedSequence[bool | int | float],
121+
SupportsBufferProtocol,
122+
],
123+
/,
124+
*,
125+
dtype: Optional[Dtype] = None,
126+
device: Optional[Device] = None,
127+
copy: "Optional[Union[bool, np._CopyMode]]" = None,
128+
**kwargs,
129+
) -> Array:
130+
"""
131+
Array API compatibility wrapper for asarray().
132+
133+
See the corresponding documentation in the array library and/or the array API
134+
specification for more details.
135+
"""
136+
if copy is False:
137+
# copy=False is not yet implemented in dask
138+
raise NotImplementedError("copy=False is not yet implemented")
139+
elif copy is True:
140+
if isinstance(obj, da.Array) and dtype is None:
141+
return obj.copy()
142+
# Go through numpy, since dask copy is no-op by default
143+
obj = np.array(obj, dtype=dtype, copy=True)
144+
return da.array(obj, dtype=dtype)
145+
else:
146+
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
147+
obj = np.asarray(obj, dtype=dtype)
148+
return da.from_array(obj)
149+
return obj
150+
151+
return da.asarray(obj, dtype=dtype, **kwargs)
152+
116153
from dask.array import (
117154
# Element wise aliases
118155
arccos as acos,

tests/test_common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,11 @@ def test_asarray_copy(library):
9494
xp = import_(library, wrapper=True)
9595
asarray = xp.asarray
9696
is_lib_func = globals()[is_functions[library]]
97-
all = xp.all
97+
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
9898

9999
if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
100100
supports_copy_false = False
101-
elif library == 'cupy':
101+
elif library in ['cupy', 'dask.array']:
102102
supports_copy_false = False
103103
else:
104104
supports_copy_false = True
@@ -133,14 +133,18 @@ def test_asarray_copy(library):
133133
assert all(b[0] == 0)
134134

135135
a = asarray([1.0], dtype=xp.float32)
136+
assert a.dtype == xp.float32
136137
b = asarray(a, dtype=xp.float64, copy=None)
137138
assert is_lib_func(b)
139+
assert b.dtype == xp.float64
138140
a[0] = 0.0
139141
assert all(b[0] == 1.0)
140142

141143
a = asarray([1.0], dtype=xp.float64)
144+
assert a.dtype == xp.float64
142145
b = asarray(a, dtype=xp.float64, copy=None)
143146
assert is_lib_func(b)
147+
assert b.dtype == xp.float64
144148
a[0] = 0.0
145149
assert all(b[0] == 0.0)
146150

0 commit comments

Comments
 (0)