@@ -31,14 +31,6 @@ class P: # pylint: disable=missing-class-docstring
31
31
kwargs : dict
32
32
33
33
34
- class UnknownShapeError (ValueError ):
35
- """
36
- `shape` contains one or more None elements.
37
-
38
- This is unsupported when running inside `jax.jit`.
39
- """
40
-
41
-
42
34
@overload
43
35
def apply_numpy_func ( # type: ignore[valid-type]
44
36
func : Callable [P , NumPyObject ],
@@ -148,15 +140,14 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
148
140
149
141
Raises
150
142
------
151
- UnknownShapeError
152
- When `shape` is unknown (one or more sizes are None) and this function was
153
- called inside `jax.jit`.
154
-
155
- Exception (varies)
156
-
157
- - When the backend disallows implicit device to host transfers and the input
158
- arrays are on a device, e.g. on GPU;
159
- - When the backend is sparse and auto-densification is disabled.
143
+ jax.errors.TracerArrayConversionError
144
+ When `xp=jax.numpy`, `shape` is unknown (it contains None on one or more axes)
145
+ and this function was called inside `jax.jit`.
146
+ RuntimeError
147
+ When `xp=sparse` and auto-densification is disabled.
148
+ Exception (backend-specific)
149
+ When the backend disallows implicit device to host transfers and the input
150
+ arrays are on a device, e.g. on GPU.
160
151
161
152
See Also
162
153
--------
@@ -241,15 +232,8 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
241
232
if any (s is None for shape in shapes for s in shape ):
242
233
# Unknown output shape. Won't work with jax.jit, but it
243
234
# can work with eager jax.
244
- try :
245
- out = wrapped (* args , ** kwargs )
246
- except jax .errors .TracerArrayConversionError :
247
- msg = (
248
- "jax.jit can't delay application of numpy functions when the shape "
249
- "of the returned array(s) is unknown. "
250
- f"shape={ shapes if multi_output else shapes [0 ]} "
251
- )
252
- raise UnknownShapeError (msg ) from None
235
+ # Raises jax.errors.TracerArrayConversionError if we're inside jax.jit.
236
+ out = wrapped (* args , ** kwargs )
253
237
254
238
else :
255
239
out = cast (
0 commit comments