From e5aa2807098e3931d6ae14a6a23a13a6b06f7aa5 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 16 May 2025 22:40:30 +0100 Subject: [PATCH 1/3] DOC: `autojit` notes --- src/array_api_extra/_lib/_utils/_helpers.py | 19 +++++++++++++++++++ src/array_api_extra/testing.py | 1 + 2 files changed, 20 insertions(+) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 99fcf65a..b9dd3e66 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -540,6 +540,25 @@ def jax_autojit( See Also -------- jax.jit : JAX JIT compilation function. + + Notes + ----- + These are useful choices *for testing purposes only*, which is how this function is + intended to be used. The output of ``jax.jit`` is a C++ level callable, that + directly dispatches to the compiled kernel after the initial call. In comparison, + ``jax_autojit`` incurs in a much higher dispatch time. + + Additionally, consider:: + + def f(x: Array, y: float, plus: bool) -> Array: + return x + y if plus else x - y + + j1 = jax.jit(f, static_argnames="plus") + j2 = jax_autojit(f) + + In the above example, ``j2`` requires a lot less setup to be tested effectively than + ``j1``, but on the flip side it means that it will be re-traced for every different + value of ``y``, which likely makes it not fit for purpose in production. """ import jax diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index c14e9a22..3979f9dd 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -96,6 +96,7 @@ def lazy_xp_function( # type: ignore[explicit-any] jax_jit : bool, optional Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after calling the :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. + This is the default behaviour. Set to False if `func` is only compatible with eager (non-jitted) JAX. Unlike with vanilla ``jax.jit``, all arguments and return types that are not JAX From 87cccd2f14c2f01238d17521787d3f51691123b8 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 16 May 2025 23:19:29 +0100 Subject: [PATCH 2/3] nit --- src/array_api_extra/_lib/_utils/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index b9dd3e66..1e91f07b 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -463,7 +463,7 @@ def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # Notes ----- The `instances` iterable must yield at least the same number of elements as the ones - returned by ``pickle_without``, but the elements do not need to be the same objects + returned by ``pickle_flatten``, but the elements do not need to be the same objects or even the same types of objects. Excess elements, if any, will be left untouched. """ iters = iter(instances), iter(rest) From bd8ecfabd52cbcda6d2c4a39942bf088c0a793ad Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Sat, 17 May 2025 13:33:41 +0200 Subject: [PATCH 3/3] fix typo --- src/array_api_extra/_lib/_utils/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 1e91f07b..b856eb41 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -546,7 +546,7 @@ def jax_autojit( These are useful choices *for testing purposes only*, which is how this function is intended to be used. The output of ``jax.jit`` is a C++ level callable, that directly dispatches to the compiled kernel after the initial call. In comparison, - ``jax_autojit`` incurs in a much higher dispatch time. + ``jax_autojit`` incurs a much higher dispatch time. Additionally, consider::