Skip to content

Commit df2ee4b

Browse files
committed
Rename to apply_lazy; add as_numpy=False param
1 parent ad18b16 commit df2ee4b

File tree

3 files changed

+66
-42
lines changed

3 files changed

+66
-42
lines changed

docs/api-reference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9-
apply_numpy_func
9+
apply_lazy
1010
at
1111
atleast_nd
1212
cov

src/array_api_extra/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Extra array functions built on top of the array API standard."""
22

33
from ._delegation import pad
4-
from ._lib._apply import apply_numpy_func
4+
from ._lib._apply import apply_lazy
55
from ._lib._funcs import (
66
at,
77
atleast_nd,
@@ -19,7 +19,7 @@
1919
# pylint: disable=duplicate-code
2020
__all__ = [
2121
"__version__",
22-
"apply_numpy_func",
22+
"apply_lazy",
2323
"at",
2424
"atleast_nd",
2525
"cov",

src/array_api_extra/_lib/_apply.py

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import ParamSpec, TypeAlias
2020

2121
import numpy as np
22+
from numpy.typing import ArrayLike
2223

2324
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[no-any-explicit]
2425
P = ParamSpec("P")
@@ -32,58 +33,74 @@ class P: # pylint: disable=missing-class-docstring
3233

3334

3435
@overload
35-
def apply_numpy_func( # type: ignore[valid-type]
36-
func: Callable[P, NumPyObject],
36+
def apply_lazy( # type: ignore[valid-type]
37+
func: Callable[P, ArrayLike],
3738
*args: Array,
3839
shape: tuple[int | None, ...] | None = None,
3940
dtype: DType | None = None,
41+
as_numpy: bool = False,
4042
xp: ModuleType | None = None,
4143
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
4244
) -> Array: ... # numpydoc ignore=GL08
4345

4446

4547
@overload
46-
def apply_numpy_func( # type: ignore[valid-type]
47-
func: Callable[P, Sequence[NumPyObject]],
48+
def apply_lazy( # type: ignore[valid-type]
49+
func: Callable[P, Sequence[ArrayLike]],
4850
*args: Array,
4951
shape: Sequence[tuple[int | None, ...]],
5052
dtype: Sequence[DType] | None = None,
53+
as_numpy: bool = False,
5154
xp: ModuleType | None = None,
5255
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
5356
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08
5457

5558

56-
def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
57-
func: Callable[P, NumPyObject | Sequence[NumPyObject]],
59+
def apply_lazy( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
60+
func: Callable[P, Array | Sequence[ArrayLike]],
5861
*args: Array,
5962
shape: tuple[int | None, ...] | Sequence[tuple[int | None, ...]] | None = None,
6063
dtype: DType | Sequence[DType] | None = None,
64+
as_numpy: bool = False,
6165
xp: ModuleType | None = None,
6266
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
6367
) -> Array | tuple[Array, ...]:
6468
"""
65-
Apply a function that operates on NumPy arrays to Array API compliant arrays.
69+
Lazily apply an eager function.
70+
71+
If the backend of the input arrays is lazy, e.g. Dask or jitted JAX, the execution
72+
of the function is delayed until the graph is materialized; if it's eager, the
73+
function is executed immediately.
6674
6775
Parameters
6876
----------
6977
func : callable
70-
The function to apply. It must accept one or more NumPy arrays or generics as
71-
positional arguments and return either a single NumPy array or generic, or a
72-
tuple or list thereof.
78+
The function to apply.
79+
80+
It must accept one or more array API compliant arrays as positional arguments.
81+
If `as_numpy=True`, inputs are converted to NumPy before they are passed to
82+
`func`.
83+
It must return either a single array-like or a sequence of array-likes.
7384
74-
It must be a pure function, i.e. without side effects, as depending on the
85+
`func` must be a pure function, i.e. without side effects, as depending on the
7586
backend it may be executed more than once.
7687
*args : Array
77-
One or more Array API compliant arrays. You need to be able to apply
78-
:func:`numpy.asarray` to them to convert them to numpy; read notes below about
79-
specific backends.
88+
One or more Array API compliant arrays.
89+
90+
If `as_numpy=True`, you need to be able to apply :func:`numpy.asarray` to them
91+
to convert them to numpy; read notes below about specific backends.
8092
shape : tuple[int | None, ...] | Sequence[tuple[int | None, ...]], optional
8193
Output shape or sequence of output shapes, one for each output of `func`.
8294
Default: assume single output and broadcast shapes of the input arrays.
8395
dtype : DType | Sequence[DType], optional
8496
Output dtype or sequence of output dtypes, one for each output of `func`.
8597
dtype(s) must belong to the same array namespace as the input arrays.
8698
Default: infer the result type(s) from the input arrays.
99+
as_numpy : bool, optional
100+
If True, convert the input arrays to NumPy before passing them to `func`.
101+
This is particularly useful to make numpy-only functions, e.g. written in Cython
102+
or Numba, work transparently with array API arrays.
103+
Default: False.
87104
xp : array_namespace, optional
88105
The standard-compatible namespace for `args`. Default: infer.
89106
**kwargs : Any, optional
@@ -95,7 +112,7 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
95112
Array | tuple[Array, ...]
96113
The result(s) of `func` applied to the input arrays, wrapped in the same
97114
array namespace as the inputs.
98-
If shape is omitted or a `tuple[int, ...]`, this is a single array.
115+
If shape is omitted or a `tuple[int | None, ...]`, this is a single array.
99116
Otherwise, it's a tuple of arrays.
100117
101118
Notes
@@ -106,23 +123,26 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
106123
When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
107124
contain any `None` elements.
108125
109-
The :doc:`jax:transfer_guard` may prevent arrays on a GPU device from being
110-
transferred back to CPU. This is treated as an implicit transfer.
126+
Using this with `as_numpy=False` is particularly useful to apply non-jittable
127+
JAX functions to arrays on GPU devices.
128+
If `as_numpy=True`, the :doc:`jax:transfer_guard` may prevent arrays on a GPU
129+
device from being transferred back to CPU. This is treated as an implicit
130+
transfer.
111131
112132
PyTorch, CuPy
113-
These backends raise by default if you attempt to convert arrays on a GPU device
114-
to NumPy.
133+
If `as_numpy=True`, these backends raise by default if you attempt to convert
134+
arrays on a GPU device to NumPy.
115135
116136
Sparse
117-
By default, sparse prevents implicit densification through
137+
If `as_numpy=True`, by default sparse prevents implicit densification through
118138
:func:`numpy.asarray`. `This safety mechanism can be disabled
119139
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
120140
121141
Dask
122142
This allows applying eager functions to dask arrays.
123143
The dask graph won't be computed.
124144
125-
`apply_numpy_func` doesn't know if `func` reduces along any axes; also, shape
145+
`apply_lazy` doesn't know if `func` reduces along any axes; also, shape
126146
changes are non-trivial in chunked Dask arrays. For these reasons, all inputs
127147
will be rechunked into a single chunk.
128148
@@ -136,7 +156,13 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
136156
If you want to distribute the calculation across multiple workers, you
137157
should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
138158
:func:`dask.array.blockwise`, or a native Dask wrapper instead of
139-
`apply_numpy_func`.
159+
`apply_lazy`.
160+
161+
Dask wrapping around other backends
162+
If `as_numpy=False`, `func` will receive as input eager arrays of the meta
163+
namespace, as defined by the `._meta` attribute of the input Dask arrays.
164+
The outputs of `func` will be wrapped by the meta namespace, and then wrapped
165+
again by Dask.
140166
141167
Raises
142168
------
@@ -202,7 +228,10 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
202228
metas = [arg._meta for arg in args if hasattr(arg, "_meta")] # pylint: disable=protected-access
203229
meta_xp = array_namespace(*metas)
204230

205-
wrapped = dask.delayed(_npfunc_wrapper(func, multi_output, meta_xp), pure=True)
231+
wrapped = dask.delayed(
232+
_apply_lazy_wrapper(func, as_numpy, multi_output, meta_xp),
233+
pure=True,
234+
)
206235
# This finalizes each arg, which is the same as arg.rechunk(-1).
207236
# Please read docstring above for why we're not using
208237
# dask.array.map_blocks or dask.array.blockwise!
@@ -227,7 +256,7 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
227256

228257
import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
229258

230-
wrapped = _npfunc_wrapper(func, multi_output, xp)
259+
wrapped = _apply_lazy_wrapper(func, as_numpy, multi_output, xp)
231260

232261
if any(s is None for shape in shapes for s in shape):
233262
# Unknown output shape. Won't work with jax.jit, but it
@@ -251,19 +280,20 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
251280

252281
else:
253282
# Eager backends
254-
wrapped = _npfunc_wrapper(func, multi_output, xp)
283+
wrapped = _apply_lazy_wrapper(func, as_numpy, multi_output, xp)
255284
out = wrapped(*args, **kwargs)
256285

257286
return out if multi_output else out[0]
258287

259288

260-
def _npfunc_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
261-
func: Callable[..., NumPyObject | Sequence[NumPyObject]],
289+
def _apply_lazy_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
290+
func: Callable[..., ArrayLike | Sequence[ArrayLike]],
291+
as_numpy: bool,
262292
multi_output: bool,
263293
xp: ModuleType,
264294
) -> Callable[..., tuple[Array, ...]]:
265295
"""
266-
Helper of `apply_numpy_func`.
296+
Helper of `apply_lazy`.
267297
268298
Given a function that accepts one or more numpy arrays as positional arguments and
269299
returns a single numpy array or a sequence of numpy arrays, return a function that
@@ -284,19 +314,13 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
284314
) -> tuple[Array, ...]: # numpydoc ignore=GL08
285315
import numpy as np # pylint: disable=import-outside-toplevel
286316

287-
args = tuple(np.asarray(arg) for arg in args)
317+
if as_numpy:
318+
args = tuple(np.asarray(arg) for arg in args)
288319
out = func(*args, **kwargs)
289320

290-
# Stay relaxed on output validation, e.g. in case func returns a
291-
# Python scalar instead of a np.generic
292321
if multi_output:
293-
if not isinstance(out, Sequence) or isinstance(out, np.ndarray):
294-
msg = "Expected multiple outputs, got a single one"
295-
raise ValueError(msg)
296-
outs = out
297-
else:
298-
outs = [cast("NumPyObject", out)]
299-
300-
return tuple(xp.asarray(o) for o in outs)
322+
assert isinstance(out, Sequence)
323+
return tuple(xp.asarray(o) for o in out)
324+
return (xp.asarray(out),)
301325

302326
return wrapper

0 commit comments

Comments
 (0)