Skip to content

Commit d90682a

Browse files
committed
Support unknown shapes
1 parent 2e0852d commit d90682a

File tree

3 files changed

+70
-21
lines changed

3 files changed

+70
-21
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
:toctree: generated
88
99
apply_numpy_func
10+
UnknownShapeError
1011
at
1112
atleast_nd
1213
cov

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Extra array functions built on top of the array API standard."""
22

33
from ._delegation import pad
4-
from ._lib._apply import apply_numpy_func
4+
from ._lib._apply import UnknownShapeError, apply_numpy_func
55
from ._lib._funcs import (
66
at,
77
atleast_nd,
@@ -18,6 +18,7 @@
1818

1919
# pylint: disable=duplicate-code
2020
__all__ = [
21+
"UnknownShapeError",
2122
"__version__",
2223
"apply_numpy_func",
2324
"at",

src/array_api_extra/_lib/_apply.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6+
import math
67
from collections.abc import Callable, Sequence
78
from functools import wraps
89
from types import ModuleType
@@ -30,11 +31,19 @@ class P: # pylint: disable=missing-class-docstring
3031
kwargs: dict
3132

3233

34+
class UnknownShapeError(ValueError):
35+
"""
36+
`shape` contains one or more None elements.
37+
38+
This is unsupported when running inside `jax.jit`.
39+
"""
40+
41+
3342
@overload
3443
def apply_numpy_func( # type: ignore[valid-type]
3544
func: Callable[P, NumPyObject],
3645
*args: Array,
37-
shape: tuple[int, ...] | None = None,
46+
shape: tuple[int | None, ...] | None = None,
3847
dtype: DType | None = None,
3948
xp: ModuleType | None = None,
4049
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
@@ -45,7 +54,7 @@ def apply_numpy_func( # type: ignore[valid-type]
4554
def apply_numpy_func( # type: ignore[valid-type]
4655
func: Callable[P, Sequence[NumPyObject]],
4756
*args: Array,
48-
shape: Sequence[tuple[int, ...]],
57+
shape: Sequence[tuple[int | None, ...]],
4958
dtype: Sequence[DType] | None = None,
5059
xp: ModuleType | None = None,
5160
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
@@ -55,7 +64,7 @@ def apply_numpy_func( # type: ignore[valid-type]
5564
def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
5665
func: Callable[P, NumPyObject | Sequence[NumPyObject]],
5766
*args: Array,
58-
shape: tuple[int, ...] | Sequence[tuple[int, ...]] | None = None,
67+
shape: tuple[int | None, ...] | Sequence[tuple[int | None, ...]] | None = None,
5968
dtype: DType | Sequence[DType] | None = None,
6069
xp: ModuleType | None = None,
6170
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
@@ -76,7 +85,7 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
7685
One or more Array API compliant arrays. You need to be able to apply
7786
:func:`numpy.asarray` to them to convert them to numpy; read notes below about
7887
specific backends.
79-
shape : tuple[int, ...] | Sequence[tuple[int, ...]], optional
88+
shape : tuple[int | None, ...] | Sequence[tuple[int | None, ...]], optional
8089
Output shape or sequence of output shapes, one for each output of `func`.
8190
Default: assume single output and broadcast shapes of the input arrays.
8291
dtype : DType | Sequence[DType], optional
@@ -102,6 +111,8 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
102111
JAX
103112
This allows applying eager functions to jitted JAX arrays, which are lazy.
104113
The function won't be applied until the JAX array is materialized.
114+
When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
115+
contain any `None` elements.
105116
106117
The :doc:`jax:transfer_guard` may prevent arrays on a GPU device from being
107118
transferred back to CPU. This is treated as an implicit transfer.
@@ -135,6 +146,18 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
135146
:func:`dask.array.blockwise`, or a native Dask wrapper instead of
136147
`apply_numpy_func`.
137148
149+
Raises
150+
------
151+
UnknownShapeError
152+
When `shape` is unknown (one or more sizes are None) and this function was
153+
called inside `jax.jit`.
154+
155+
Exception (varies)
156+
157+
- When the backend disallows implicit device to host transfers and the input
158+
arrays are on a device, e.g. on GPU;
159+
- When the backend is sparse and auto-densification is disabled.
160+
138161
See Also
139162
--------
140163
jax.transfer_guard
@@ -147,13 +170,16 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
147170
xp = array_namespace(*args)
148171

149172
# Normalize and validate shape and dtype
173+
shapes: list[tuple[int | None, ...]]
174+
dtypes: list[DType]
150175
multi_output = False
176+
151177
if shape is None:
152178
shapes = [xp.broadcast_shapes(*(arg.shape for arg in args))]
153-
elif isinstance(shape, tuple) and all(isinstance(s, int) for s in shape):
154-
shapes = [shape]
179+
elif isinstance(shape, tuple) and all(isinstance(s, int | None) for s in shape):
180+
shapes = [shape] # pyright: ignore[reportAssignmentType]
155181
else:
156-
shapes = list(shape)
182+
shapes = list(shape) # type: ignore[arg-type] # pyright: ignore[reportAssignmentType]
157183
multi_output = True
158184

159185
if dtype is None:
@@ -186,13 +212,19 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
186212
meta_xp = array_namespace(*metas)
187213

188214
wrapped = dask.delayed(_npfunc_wrapper(func, multi_output, meta_xp), pure=True)
189-
# This finalizes each arg, which is the same as arg.rechunk(-1)
215+
# This finalizes each arg, which is the same as arg.rechunk(-1).
190216
# Please read docstring above for why we're not using
191217
# dask.array.map_blocks or dask.array.blockwise!
192218
delayed_out = wrapped(*args, **kwargs)
193219

194220
out = tuple(
195-
xp.from_delayed(delayed_out[i], shape=shape, dtype=dtype, meta=metas[0])
221+
xp.from_delayed(
222+
delayed_out[i],
223+
# Dask's unknown shapes diverge from the Array API specification
224+
shape=tuple(math.nan if s is None else s for s in shape),
225+
dtype=dtype,
226+
meta=metas[0],
227+
)
196228
for i, (shape, dtype) in enumerate(zip(shapes, dtypes, strict=True))
197229
)
198230

@@ -205,18 +237,33 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
205237
import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
206238

207239
wrapped = _npfunc_wrapper(func, multi_output, xp)
208-
out = cast(
209-
tuple[Array, ...],
210-
jax.pure_callback(
211-
wrapped,
212-
tuple(
213-
jax.ShapeDtypeStruct(s, dt) # pyright: ignore[reportUnknownArgumentType]
214-
for s, dt in zip(shapes, dtypes, strict=True)
240+
241+
if any(s is None for shape in shapes for s in shape):
242+
# Unknown output shape. Won't work with jax.jit, but it
243+
# can work with eager jax.
244+
try:
245+
out = wrapped(*args, **kwargs)
246+
except jax.errors.TracerArrayConversionError:
247+
msg = (
248+
"jax.jit can't delay application of numpy functions when the shape "
249+
"of the returned array(s) is unknown. "
250+
f"shape={shapes if multi_output else shapes[0]}"
251+
)
252+
raise UnknownShapeError(msg) from None
253+
254+
else:
255+
out = cast(
256+
tuple[Array, ...],
257+
jax.pure_callback(
258+
wrapped,
259+
tuple(
260+
jax.ShapeDtypeStruct(shape, dtype) # pyright: ignore[reportUnknownArgumentType]
261+
for shape, dtype in zip(shapes, dtypes, strict=True)
262+
),
263+
*args,
264+
**kwargs,
215265
),
216-
*args,
217-
**kwargs,
218-
),
219-
)
266+
)
220267

221268
else:
222269
# Eager backends

0 commit comments

Comments
 (0)