3
3
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
4
4
from __future__ import annotations
5
5
6
+ import math
6
7
from collections .abc import Callable , Sequence
7
8
from functools import wraps
8
9
from types import ModuleType
@@ -30,11 +31,19 @@ class P: # pylint: disable=missing-class-docstring
30
31
kwargs : dict
31
32
32
33
34
+ class UnknownShapeError (ValueError ):
35
+ """
36
+ `shape` contains one or more None elements.
37
+
38
+ This is unsupported when running inside `jax.jit`.
39
+ """
40
+
41
+
33
42
@overload
34
43
def apply_numpy_func ( # type: ignore[valid-type]
35
44
func : Callable [P , NumPyObject ],
36
45
* args : Array ,
37
- shape : tuple [int , ...] | None = None ,
46
+ shape : tuple [int | None , ...] | None = None ,
38
47
dtype : DType | None = None ,
39
48
xp : ModuleType | None = None ,
40
49
** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
@@ -45,7 +54,7 @@ def apply_numpy_func( # type: ignore[valid-type]
45
54
def apply_numpy_func ( # type: ignore[valid-type]
46
55
func : Callable [P , Sequence [NumPyObject ]],
47
56
* args : Array ,
48
- shape : Sequence [tuple [int , ...]],
57
+ shape : Sequence [tuple [int | None , ...]],
49
58
dtype : Sequence [DType ] | None = None ,
50
59
xp : ModuleType | None = None ,
51
60
** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
@@ -55,7 +64,7 @@ def apply_numpy_func( # type: ignore[valid-type]
55
64
def apply_numpy_func ( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
56
65
func : Callable [P , NumPyObject | Sequence [NumPyObject ]],
57
66
* args : Array ,
58
- shape : tuple [int , ...] | Sequence [tuple [int , ...]] | None = None ,
67
+ shape : tuple [int | None , ...] | Sequence [tuple [int | None , ...]] | None = None ,
59
68
dtype : DType | Sequence [DType ] | None = None ,
60
69
xp : ModuleType | None = None ,
61
70
** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
@@ -76,7 +85,7 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
76
85
One or more Array API compliant arrays. You need to be able to apply
77
86
:func:`numpy.asarray` to them to convert them to numpy; read notes below about
78
87
specific backends.
79
- shape : tuple[int, ...] | Sequence[tuple[int, ...]], optional
88
+ shape : tuple[int | None , ...] | Sequence[tuple[int | None , ...]], optional
80
89
Output shape or sequence of output shapes, one for each output of `func`.
81
90
Default: assume single output and broadcast shapes of the input arrays.
82
91
dtype : DType | Sequence[DType], optional
@@ -102,6 +111,8 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
102
111
JAX
103
112
This allows applying eager functions to jitted JAX arrays, which are lazy.
104
113
The function won't be applied until the JAX array is materialized.
114
+ When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
115
+ contain any `None` elements.
105
116
106
117
The :doc:`jax:transfer_guard` may prevent arrays on a GPU device from being
107
118
transferred back to CPU. This is treated as an implicit transfer.
@@ -135,6 +146,18 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
135
146
:func:`dask.array.blockwise`, or a native Dask wrapper instead of
136
147
`apply_numpy_func`.
137
148
149
+ Raises
150
+ ------
151
+ UnknownShapeError
152
+ When `shape` is unknown (one or more sizes are None) and this function was
153
+ called inside `jax.jit`.
154
+
155
+ Exception (varies)
156
+
157
+ - When the backend disallows implicit device to host transfers and the input
158
+ arrays are on a device, e.g. on GPU;
159
+ - When the backend is sparse and auto-densification is disabled.
160
+
138
161
See Also
139
162
--------
140
163
jax.transfer_guard
@@ -147,13 +170,16 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
147
170
xp = array_namespace (* args )
148
171
149
172
# Normalize and validate shape and dtype
173
+ shapes : list [tuple [int | None , ...]]
174
+ dtypes : list [DType ]
150
175
multi_output = False
176
+
151
177
if shape is None :
152
178
shapes = [xp .broadcast_shapes (* (arg .shape for arg in args ))]
153
- elif isinstance (shape , tuple ) and all (isinstance (s , int ) for s in shape ):
154
- shapes = [shape ]
179
+ elif isinstance (shape , tuple ) and all (isinstance (s , int | None ) for s in shape ):
180
+ shapes = [shape ] # pyright: ignore[reportAssignmentType]
155
181
else :
156
- shapes = list (shape )
182
+ shapes = list (shape ) # type: ignore[arg-type] # pyright: ignore[reportAssignmentType]
157
183
multi_output = True
158
184
159
185
if dtype is None :
@@ -186,13 +212,19 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
186
212
meta_xp = array_namespace (* metas )
187
213
188
214
wrapped = dask .delayed (_npfunc_wrapper (func , multi_output , meta_xp ), pure = True )
189
- # This finalizes each arg, which is the same as arg.rechunk(-1)
215
+ # This finalizes each arg, which is the same as arg.rechunk(-1).
190
216
# Please read docstring above for why we're not using
191
217
# dask.array.map_blocks or dask.array.blockwise!
192
218
delayed_out = wrapped (* args , ** kwargs )
193
219
194
220
out = tuple (
195
- xp .from_delayed (delayed_out [i ], shape = shape , dtype = dtype , meta = metas [0 ])
221
+ xp .from_delayed (
222
+ delayed_out [i ],
223
+ # Dask marks unknown sizes with NaN, whereas the Array API specification uses None
224
+ shape = tuple (math .nan if s is None else s for s in shape ),
225
+ dtype = dtype ,
226
+ meta = metas [0 ],
227
+ )
196
228
for i , (shape , dtype ) in enumerate (zip (shapes , dtypes , strict = True ))
197
229
)
198
230
@@ -205,18 +237,33 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
205
237
import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
206
238
207
239
wrapped = _npfunc_wrapper (func , multi_output , xp )
208
- out = cast (
209
- tuple [Array , ...],
210
- jax .pure_callback (
211
- wrapped ,
212
- tuple (
213
- jax .ShapeDtypeStruct (s , dt ) # pyright: ignore[reportUnknownArgumentType]
214
- for s , dt in zip (shapes , dtypes , strict = True )
240
+
241
+ if any (s is None for shape in shapes for s in shape ):
242
+ # Unknown output shape. Won't work with jax.jit, but it
243
+ # can work with eager jax.
244
+ try :
245
+ out = wrapped (* args , ** kwargs )
246
+ except jax .errors .TracerArrayConversionError :
247
+ msg = (
248
+ "jax.jit can't delay application of numpy functions when the shape "
249
+ "of the returned array(s) is unknown. "
250
+ f"shape={ shapes if multi_output else shapes [0 ]} "
251
+ )
252
+ raise UnknownShapeError (msg ) from None
253
+
254
+ else :
255
+ out = cast (
256
+ tuple [Array , ...],
257
+ jax .pure_callback (
258
+ wrapped ,
259
+ tuple (
260
+ jax .ShapeDtypeStruct (shape , dtype ) # pyright: ignore[reportUnknownArgumentType]
261
+ for shape , dtype in zip (shapes , dtypes , strict = True )
262
+ ),
263
+ * args ,
264
+ ** kwargs ,
215
265
),
216
- * args ,
217
- ** kwargs ,
218
- ),
219
- )
266
+ )
220
267
221
268
else :
222
269
# Eager backends
0 commit comments