19
19
from typing import ParamSpec , TypeAlias
20
20
21
21
import numpy as np
22
+ from numpy .typing import ArrayLike
22
23
23
24
NumPyObject : TypeAlias = np .ndarray [Any , Any ] | np .generic # type: ignore[no-any-explicit]
24
25
P = ParamSpec ("P" )
@@ -32,58 +33,74 @@ class P: # pylint: disable=missing-class-docstring
32
33
33
34
34
35
@overload
35
- def apply_numpy_func ( # type: ignore[valid-type]
36
- func : Callable [P , NumPyObject ],
36
+ def apply_lazy ( # type: ignore[valid-type]
37
+ func : Callable [P , ArrayLike ],
37
38
* args : Array ,
38
39
shape : tuple [int | None , ...] | None = None ,
39
40
dtype : DType | None = None ,
41
+ as_numpy : bool = False ,
40
42
xp : ModuleType | None = None ,
41
43
** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
42
44
) -> Array : ... # numpydoc ignore=GL08
43
45
44
46
45
47
@overload
46
- def apply_numpy_func ( # type: ignore[valid-type]
47
- func : Callable [P , Sequence [NumPyObject ]],
48
+ def apply_lazy ( # type: ignore[valid-type]
49
+ func : Callable [P , Sequence [ArrayLike ]],
48
50
* args : Array ,
49
51
shape : Sequence [tuple [int | None , ...]],
50
52
dtype : Sequence [DType ] | None = None ,
53
+ as_numpy : bool = False ,
51
54
xp : ModuleType | None = None ,
52
55
** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
53
56
) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
54
57
55
58
56
- def apply_numpy_func ( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
57
- func : Callable [P , NumPyObject | Sequence [NumPyObject ]],
59
+ def apply_lazy ( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
60
+ func : Callable [P , Array | Sequence [ArrayLike ]],
58
61
* args : Array ,
59
62
shape : tuple [int | None , ...] | Sequence [tuple [int | None , ...]] | None = None ,
60
63
dtype : DType | Sequence [DType ] | None = None ,
64
+ as_numpy : bool = False ,
61
65
xp : ModuleType | None = None ,
62
66
** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
63
67
) -> Array | tuple [Array , ...]:
64
68
"""
65
- Apply a function that operates on NumPy arrays to Array API compliant arrays.
69
+ Lazily apply an eager function.
70
+
71
+ If the backend of the input arrays is lazy, e.g. Dask or jitted JAX, the execution
72
+ of the function is delayed until the graph is materialized; if it's eager, the
73
+ function is executed immediately.
66
74
67
75
Parameters
68
76
----------
69
77
func : callable
70
- The function to apply. It must accept one or more NumPy arrays or generics as
71
- positional arguments and return either a single NumPy array or generic, or a
72
- tuple or list thereof.
78
+ The function to apply.
79
+
80
+ It must accept one or more array API compliant arrays as positional arguments.
81
+ If `as_numpy=True`, inputs are converted to NumPy before they are passed to
82
+ `func`.
83
+ It must return either a single array-like or a sequence of array-likes.
73
84
74
- It must be a pure function, i.e. without side effects, as depending on the
85
+ `func` must be a pure function, i.e. without side effects, as depending on the
75
86
backend it may be executed more than once.
76
87
*args : Array
77
- One or more Array API compliant arrays. You need to be able to apply
78
- :func:`numpy.asarray` to them to convert them to numpy; read notes below about
79
- specific backends.
88
+ One or more Array API compliant arrays.
89
+
90
+ If `as_numpy=True`, you need to be able to apply :func:`numpy.asarray` to them
91
+ to convert them to numpy; read notes below about specific backends.
80
92
shape : tuple[int | None, ...] | Sequence[tuple[int, ...]], optional
81
93
Output shape or sequence of output shapes, one for each output of `func`.
82
94
Default: assume single output and broadcast shapes of the input arrays.
83
95
dtype : DType | Sequence[DType], optional
84
96
Output dtype or sequence of output dtypes, one for each output of `func`.
85
97
dtype(s) must belong to the same array namespace as the input arrays.
86
98
Default: infer the result type(s) from the input arrays.
99
+ as_numpy : bool, optional
100
+ If True, convert the input arrays to NumPy before passing them to `func`.
101
+ This is particularly useful to make numpy-only functions, e.g. written in Cython
102
+ or Numba, work transparently API arrays.
103
+ Default: False.
87
104
xp : array_namespace, optional
88
105
The standard-compatible namespace for `args`. Default: infer.
89
106
**kwargs : Any, optional
@@ -95,7 +112,7 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
95
112
Array | tuple[Array, ...]
96
113
The result(s) of `func` applied to the input arrays, wrapped in the same
97
114
array namespace as the inputs.
98
- If shape is omitted or a `tuple[int, ...]`, this is a single array.
115
+ If shape is omitted or a `tuple[int | None , ...]`, this is a single array.
99
116
Otherwise, it's a tuple of arrays.
100
117
101
118
Notes
@@ -106,23 +123,26 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
106
123
When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
107
124
contain any `None` elements.
108
125
109
- The :doc:`jax:transfer_guard` may prevent arrays on a GPU device from being
110
- transferred back to CPU. This is treated as an implicit transfer.
126
+ Using this with `as_numpy=False` is particularly useful to apply non-jittable
127
+ JAX functions to arrays on GPU devices.
128
+ If `as_numpy=True`, the :doc:`jax:transfer_guard` may prevent arrays on a GPU
129
+ device from being transferred back to CPU. This is treated as an implicit
130
+ transfer.
111
131
112
132
PyTorch, CuPy
113
- These backends raise by default if you attempt to convert arrays on a GPU device
114
- to NumPy.
133
+ If `as_numpy=True`, these backends raise by default if you attempt to convert
134
+ arrays on a GPU device to NumPy.
115
135
116
136
Sparse
117
- By default, sparse prevents implicit densification through
137
+ If `as_numpy=True`, by default sparse prevents implicit densification through
118
138
:func:`numpy.asarray`. `This safety mechanism can be disabled
119
139
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
120
140
121
141
Dask
122
142
This allows applying eager functions to dask arrays.
123
143
The dask graph won't be computed.
124
144
125
- `apply_numpy_func ` doesn't know if `func` reduces along any axes; also, shape
145
+ `apply_lazy ` doesn't know if `func` reduces along any axes; also, shape
126
146
changes are non-trivial in chunked Dask arrays. For these reasons, all inputs
127
147
will be rechunked into a single chunk.
128
148
@@ -136,7 +156,13 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
136
156
If you want to distribute the calculation across multiple workers, you
137
157
should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
138
158
:func:`dask.array.blockwise`, or a native Dask wrapper instead of
139
- `apply_numpy_func`.
159
+ `apply_lazy`.
160
+
161
+ Dask wrapping around other backends
162
+ If `as_numpy=False`, `func` will receive in input eager arrays of the meta
163
+ namespace, as defined by the `._meta` attribute of the input Dask arrays.
164
+ The outputs of `func` will be wrapped by the meta namespace, and then wrapped
165
+ again by Dask.
140
166
141
167
Raises
142
168
------
@@ -202,7 +228,10 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
202
228
metas = [arg ._meta for arg in args if hasattr (arg , "_meta" )] # pylint: disable=protected-access
203
229
meta_xp = array_namespace (* metas )
204
230
205
- wrapped = dask .delayed (_npfunc_wrapper (func , multi_output , meta_xp ), pure = True )
231
+ wrapped = dask .delayed (
232
+ _apply_lazy_wrapper (func , as_numpy , multi_output , meta_xp ),
233
+ pure = True ,
234
+ )
206
235
# This finalizes each arg, which is the same as arg.rechunk(-1).
207
236
# Please read docstring above for why we're not using
208
237
# dask.array.map_blocks or dask.array.blockwise!
@@ -227,7 +256,7 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
227
256
228
257
import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
229
258
230
- wrapped = _npfunc_wrapper (func , multi_output , xp )
259
+ wrapped = _apply_lazy_wrapper (func , as_numpy , multi_output , xp )
231
260
232
261
if any (s is None for shape in shapes for s in shape ):
233
262
# Unknown output shape. Won't work with jax.jit, but it
@@ -251,19 +280,20 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
251
280
252
281
else :
253
282
# Eager backends
254
- wrapped = _npfunc_wrapper (func , multi_output , xp )
283
+ wrapped = _apply_lazy_wrapper (func , as_numpy , multi_output , xp )
255
284
out = wrapped (* args , ** kwargs )
256
285
257
286
return out if multi_output else out [0 ]
258
287
259
288
260
- def _npfunc_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
261
- func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
289
+ def _apply_lazy_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
290
+ func : Callable [..., ArrayLike | Sequence [ArrayLike ]],
291
+ as_numpy : bool ,
262
292
multi_output : bool ,
263
293
xp : ModuleType ,
264
294
) -> Callable [..., tuple [Array , ...]]:
265
295
"""
266
- Helper of `apply_numpy_func `.
296
+ Helper of `apply_lazy `.
267
297
268
298
Given a function that accepts one or more numpy arrays as positional arguments and
269
299
returns a single numpy array or a sequence of numpy arrays, return a function that
@@ -284,19 +314,13 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
284
314
) -> tuple [Array , ...]: # numpydoc ignore=GL08
285
315
import numpy as np # pylint: disable=import-outside-toplevel
286
316
287
- args = tuple (np .asarray (arg ) for arg in args )
317
+ if as_numpy :
318
+ args = tuple (np .asarray (arg ) for arg in args )
288
319
out = func (* args , ** kwargs )
289
320
290
- # Stay relaxed on output validation, e.g. in case func returns a
291
- # Python scalar instead of a np.generic
292
321
if multi_output :
293
- if not isinstance (out , Sequence ) or isinstance (out , np .ndarray ):
294
- msg = "Expected multiple outputs, got a single one"
295
- raise ValueError (msg )
296
- outs = out
297
- else :
298
- outs = [cast ("NumPyObject" , out )]
299
-
300
- return tuple (xp .asarray (o ) for o in outs )
322
+ assert isinstance (out , Sequence )
323
+ return tuple (xp .asarray (o ) for o in out )
324
+ return (xp .asarray (out ),)
301
325
302
326
return wrapper
0 commit comments