Skip to content

Commit ad18b16

Browse files
committed
Revert bespoke exception
1 parent 35e46fd commit ad18b16

File tree

3 files changed

+11
-29
lines changed

3 files changed

+11
-29
lines changed

docs/api-reference.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
:toctree: generated
88
99
apply_numpy_func
10-
UnknownShapeError
1110
at
1211
atleast_nd
1312
cov

src/array_api_extra/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Extra array functions built on top of the array API standard."""
22

33
from ._delegation import pad
4-
from ._lib._apply import UnknownShapeError, apply_numpy_func
4+
from ._lib._apply import apply_numpy_func
55
from ._lib._funcs import (
66
at,
77
atleast_nd,
@@ -18,7 +18,6 @@
1818

1919
# pylint: disable=duplicate-code
2020
__all__ = [
21-
"UnknownShapeError",
2221
"__version__",
2322
"apply_numpy_func",
2423
"at",

src/array_api_extra/_lib/_apply.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,6 @@ class P: # pylint: disable=missing-class-docstring
3131
kwargs: dict
3232

3333

34-
class UnknownShapeError(ValueError):
35-
"""
36-
`shape` contains one or more None elements.
37-
38-
This is unsupported when running inside `jax.jit`.
39-
"""
40-
41-
4234
@overload
4335
def apply_numpy_func( # type: ignore[valid-type]
4436
func: Callable[P, NumPyObject],
@@ -148,15 +140,14 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
148140
149141
Raises
150142
------
151-
UnknownShapeError
152-
When `shape` is unknown (one or more sizes are None) and this function was
153-
called inside `jax.jit`.
154-
155-
Exception (varies)
156-
157-
- When the backend disallows implicit device to host transfers and the input
158-
arrays are on a device, e.g. on GPU;
159-
- When the backend is sparse and auto-densification is disabled.
143+
jax.errors.TracerArrayConversionError
144+
When `xp=jax.numpy`, `shape` is unknown (it contains None on one or more axes)
145+
and this function was called inside `jax.jit`.
146+
RuntimeError
147+
When `xp=sparse` and auto-densification is disabled.
148+
Exception (backend-specific)
149+
When the backend disallows implicit device to host transfers and the input
150+
arrays are on a device, e.g. on GPU.
160151
161152
See Also
162153
--------
@@ -241,15 +232,8 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
241232
if any(s is None for shape in shapes for s in shape):
242233
# Unknown output shape. Won't work with jax.jit, but it
243234
# can work with eager jax.
244-
try:
245-
out = wrapped(*args, **kwargs)
246-
except jax.errors.TracerArrayConversionError:
247-
msg = (
248-
"jax.jit can't delay application of numpy functions when the shape "
249-
"of the returned array(s) is unknown. "
250-
f"shape={shapes if multi_output else shapes[0]}"
251-
)
252-
raise UnknownShapeError(msg) from None
235+
# Raises jax.errors.TracerArrayConversionError if we're inside jax.jit.
236+
out = wrapped(*args, **kwargs)
253237

254238
else:
255239
out = cast(

0 commit comments

Comments
 (0)