Skip to content

Commit 8f36e4e

Browse files
committed
Merge branch 'main' into dask-new
[skip cirrus] [skip circle]
2 parents e40dc94 + 54d1876 commit 8f36e4e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+2709
-1243
lines changed

doc/source/dev/api-dev/array_api.rst

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,32 @@ Note that there is a GitHub Actions workflow which tests with array-api-strict,
331331
PyTorch, and JAX on CPU.
332332

333333

334+
Testing the JAX JIT compiler
335+
----------------------------
336+
The `JAX JIT compiler <https://jax.readthedocs.io/en/latest/jit-compilation.html>`_
337+
introduces special restrictions to all code wrapped by `@jax.jit`, which are not
338+
present when running JAX in eager mode. Notably, boolean masks in `__getitem__`
339+
and `.at` aren't supported, and you can't materialize the arrays by applying
340+
`bool()`, `float()`, `np.asarray()` etc. to them.
341+
342+
To properly test scipy with JAX, you need to wrap the tested scipy functions
343+
with `@jax.jit` before they are called by the unit tests.
344+
To achieve this, you should tag them as follows in your test module::
345+
346+
from scipy._lib._lazy_testing import lazy_xp_function
347+
from scipy.mymodule import toto
348+
349+
lazy_xp_function(toto)
350+
351+
def test_toto(xp):
352+
a = xp.asarray([1, 2, 3])
353+
b = xp.asarray([0, 2, 5])
354+
# When xp==jax.numpy, toto is wrapped with @jax.jit
355+
xp_assert_close(toto(a, b), a)
356+
357+
See full documentation in `scipy/_lib/_lazy_testing.py`.
358+
359+
334360
Additional information
335361
----------------------
336362

mypy.ini

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ ignore_missing_imports = True
8888
[mypy-array_api_strict]
8989
ignore_missing_imports = True
9090

91+
[mypy-jax]
92+
# Typed, but cumbersome to install in CI (depends on scipy)
93+
ignore_missing_imports = True
94+
9195
[mypy-sphinx.*]
9296
ignore_missing_imports = True
9397

@@ -751,3 +755,6 @@ ignore_errors = True
751755

752756
[mypy-scipy._lib.array_api_compat.*]
753757
ignore_errors = True
758+
759+
[mypy-scipy._lib.array_api_extra.*]
760+
ignore_errors = True

scipy/_lib/_lazy_testing.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from collections.abc import Callable, Iterable, Sequence
2+
from types import ModuleType
3+
import pytest
4+
from scipy._lib._array_api import is_jax
5+
6+
7+
def lazy_xp_function(
8+
func: Callable,
9+
*,
10+
jax_jit: bool = True,
11+
static_argnums: int | Sequence[int] | None=None,
12+
static_argnames: str | Iterable[str] | None=None,
13+
) -> None:
14+
"""Tag a function, which must be imported in the test module globals,
15+
so that when any tests defined in the same module are executed with
16+
xp=jax.numpy the function is replaced with a jitted version of itself.
17+
18+
This will be later expanded to provide test coverage for other lazy backends,
19+
e.g. Dask.
20+
21+
Example::
22+
23+
# test_mymodule.py:
24+
from scipy._lib._lazy_testing import lazy_xp_function
25+
from scipy.mymodule import myfunc
26+
27+
lazy_xp_function(myfunc)
28+
29+
def test_myfunc(xp):
30+
a = xp.asarray([1, 2])
31+
# When xp=jax.numpy, this is the same as
32+
# b = jax.jit(myfunc)(a)
33+
b = myfunc(a)
34+
35+
Parameters
36+
----------
37+
func : callable
38+
Function to be tested
39+
jax_jit : bool, optional
40+
Set to True to replace `func` with `jax.jit(func)` when calling the
41+
`patch_lazy_xp_functions` test helper with `xp=jax.numpy`.
42+
Set to False if `func` is only compatible with eager (non-jitted) JAX.
43+
Default: True.
44+
static_argnums : int | Sequence[int], optional
45+
Passed to jax.jit.
46+
Positional arguments to treat as static (trace- and compile-time constant).
47+
Default: infer from static_argnames using `inspect.signature(func)`.
48+
static_argnames : str | Iterable[str], optional
49+
Passed to jax.jit.
50+
Named arguments to treat as static (compile-time constant).
51+
Default: infer from static_argnums using `inspect.signature(func)`.
52+
53+
Notes
54+
-----
55+
A test function can circumvent this monkey-patching system by calling `func` an
56+
attribute of the original module. You need to sanitize your code to
57+
make sure this does not happen.
58+
59+
Example::
60+
61+
import mymodule
62+
from mymodule import myfunc
63+
64+
lazy_xp_function(myfunc)
65+
66+
def test_myfunc(xp):
67+
a = xp.asarray([1, 2])
68+
b = myfunc(a) # This is jitted when xp=jax.numpy
69+
c = mymodule.myfunc(a) # This is not
70+
71+
See Also
72+
--------
73+
patch_lazy_xp_functions
74+
jax.jit: https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
75+
"""
76+
if jax_jit:
77+
func._lazy_jax_jit_kwargs = { # type: ignore[attr-defined]
78+
"static_argnums": static_argnums,
79+
"static_argnames": static_argnames,
80+
}
81+
82+
83+
def patch_lazy_xp_functions(
84+
xp: ModuleType, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
85+
) -> None:
86+
"""If xp==jax.numpy, search for all functions which have been tagged by
87+
`lazy_xp_function` in the globals of the module that defines the current test
88+
and wrap them with `jax.jit`. Unwrap them at the end of the test.
89+
90+
Parameters
91+
----------
92+
xp: module
93+
Array namespace to be tested
94+
request: pytest.FixtureRequest
95+
Pytest fixture, as acquired by the test itself or by one of its fixtures.
96+
monkeypatch: pytest.MonkeyPatch
97+
Pytest fixture, as acquired by the test itself or by one of its fixtures.
98+
99+
See Also
100+
--------
101+
lazy_xp_function
102+
https://docs.pytest.org/en/stable/reference/reference.html#std-fixture-request
103+
"""
104+
if is_jax(xp):
105+
import jax
106+
107+
globals_ = request.module.__dict__
108+
for name, func in globals_.items():
109+
kwargs = getattr(func, "_lazy_jax_jit_kwargs", None)
110+
if kwargs is not None:
111+
monkeypatch.setitem(globals_, name, jax.jit(func, **kwargs))

scipy/_lib/meson.build

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ python_sources = [
122122
'_elementwise_iterative_method.py',
123123
'_finite_differences.py',
124124
'_gcutils.py',
125+
'_lazy_testing.py',
125126
'_pep440.py',
126127
'_testutils.py',
127128
'_threadsafety.py',
@@ -217,21 +218,32 @@ py3.install_sources(
217218
# `array_api_extra` install to simplify import path;
218219
# should be updated whenever new files are added to `array_api_extra`
219220

221+
py3.install_sources(
222+
[
223+
'array_api_extra/src/array_api_extra/_lib/_utils/__init__.py',
224+
'array_api_extra/src/array_api_extra/_lib/_utils/_compat.py',
225+
'array_api_extra/src/array_api_extra/_lib/_utils/_compat.pyi',
226+
'array_api_extra/src/array_api_extra/_lib/_utils/_helpers.py',
227+
'array_api_extra/src/array_api_extra/_lib/_utils/_typing.py',
228+
229+
],
230+
subdir: 'scipy/_lib/array_api_extra/_lib/_utils',
231+
)
232+
220233
py3.install_sources(
221234
[
222235
'array_api_extra/src/array_api_extra/_lib/__init__.py',
223-
'array_api_extra/src/array_api_extra/_lib/_compat.py',
224-
'array_api_extra/src/array_api_extra/_lib/_compat.pyi',
225-
'array_api_extra/src/array_api_extra/_lib/_utils.py',
226-
'array_api_extra/src/array_api_extra/_lib/_typing.py',
236+
'array_api_extra/src/array_api_extra/_lib/_backends.py',
237+
'array_api_extra/src/array_api_extra/_lib/_funcs.py',
238+
'array_api_extra/src/array_api_extra/_lib/_testing.py',
227239
],
228240
subdir: 'scipy/_lib/array_api_extra/_lib',
229241
)
230242

231243
py3.install_sources(
232244
[
233245
'array_api_extra/src/array_api_extra/__init__.py',
234-
'array_api_extra/src/array_api_extra/_funcs.py',
246+
'array_api_extra/src/array_api_extra/_delegation.py',
235247
],
236248
subdir: 'scipy/_lib/array_api_extra',
237249
)

scipy/_lib/tests/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ python_sources = [
1111
'test_deprecation.py',
1212
'test_doccer.py',
1313
'test_import_cycles.py',
14+
'test_lazy_testing.py',
1415
'test_public_api.py',
1516
'test_scipy_version.py',
1617
'test_tmpdirs.py',

scipy/_lib/tests/test__util.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from scipy._lib._array_api import (xp_assert_equal, xp_assert_close, is_numpy,
1616
is_array_api_strict)
17+
from scipy._lib._lazy_testing import lazy_xp_function
1718
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
1819
getfullargspec_no_self, FullArgSpec,
1920
rng_integers, _validate_int, _rename_parameter,
@@ -23,6 +24,11 @@
2324

2425
skip_xp_backends = pytest.mark.skip_xp_backends
2526

27+
lazy_xp_function(_contains_nan, static_argnames=("nan_policy", "xp_omit_okay", "xp"))
28+
# FIXME @jax.jit fails: complex bool mask
29+
lazy_xp_function(_lazywhere, jax_jit=False, static_argnames=("f", "f2"))
30+
31+
2632
@pytest.mark.slow
2733
def test__aligned_zeros():
2834
niter = 10
@@ -344,6 +350,7 @@ def test_contains_nan_with_strings(self):
344350
data4 = np.array([["1", 2], [3, np.nan]], dtype='object')
345351
assert _contains_nan(data4)
346352

353+
@pytest.mark.skip_xp_backends("jax.numpy", reason="lazy backends tested separately")
347354
@pytest.mark.parametrize("nan_policy", ['propagate', 'omit', 'raise'])
348355
def test_array_api(self, xp, nan_policy):
349356
rng = np.random.default_rng(932347235892482)
@@ -359,9 +366,40 @@ def test_array_api(self, xp, nan_policy):
359366
elif nan_policy == 'omit' and not is_numpy(xp):
360367
with pytest.raises(ValueError, match="nan_policy='omit' is incompatible"):
361368
_contains_nan(x, nan_policy)
369+
assert _contains_nan(x, nan_policy, xp_omit_okay=True)
362370
elif nan_policy == 'propagate':
363371
assert _contains_nan(x, nan_policy)
364372

373+
@pytest.mark.skip_xp_backends("numpy", reason="lazy backends only")
374+
@pytest.mark.skip_xp_backends("cupy", reason="lazy backends only")
375+
@pytest.mark.skip_xp_backends("array_api_strict", reason="lazy backends only")
376+
@pytest.mark.skip_xp_backends("torch", reason="lazy backends only")
377+
def test_array_api_lazy(self, xp):
378+
rng = np.random.default_rng(932347235892482)
379+
x0 = rng.random(size=(2, 3, 4))
380+
x = xp.asarray(x0)
381+
382+
xp_assert_equal(_contains_nan(x), xp.asarray(False))
383+
xp_assert_equal(_contains_nan(x, "propagate"), xp.asarray(False))
384+
xp_assert_equal(_contains_nan(x, "omit", xp_omit_okay=True), xp.asarray(False))
385+
# Lazy arrays don't support "omit" and "raise" policies
386+
# TODO test that we're emitting a user-friendly error message.
387+
# Blocked by https://github.com/data-apis/array-api-compat/pull/228
388+
with pytest.raises(TypeError):
389+
_contains_nan(x, "omit")
390+
with pytest.raises(TypeError):
391+
_contains_nan(x, "raise")
392+
393+
x = xpx.at(x)[1, 2, 1].set(np.nan)
394+
395+
xp_assert_equal(_contains_nan(x), xp.asarray(True))
396+
xp_assert_equal(_contains_nan(x, "propagate"), xp.asarray(True))
397+
xp_assert_equal(_contains_nan(x, "omit", xp_omit_okay=True), xp.asarray(True))
398+
with pytest.raises(TypeError):
399+
_contains_nan(x, "omit")
400+
with pytest.raises(TypeError):
401+
_contains_nan(x, "raise")
402+
365403

366404
def test__rng_html_rewrite():
367405
def mock_str():

scipy/_lib/tests/test_array_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
)
88
from scipy._lib import array_api_extra as xpx
99
from scipy._lib._array_api_no_0d import xp_assert_equal as xp_assert_equal_no_0d
10+
from scipy._lib._lazy_testing import lazy_xp_function
11+
12+
lazy_xp_function(_asarray, static_argnames=(
13+
"dtype", "order", "copy", "xp", "check_finite", "subok"))
14+
lazy_xp_function(xp_copy, static_argnames=("xp", ))
1015

1116
skip_xp_backends = pytest.mark.skip_xp_backends
1217

scipy/_lib/tests/test_lazy_testing.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import pytest
2+
from scipy._lib._array_api import array_namespace, is_jax, xp_assert_equal
3+
from scipy._lib._lazy_testing import lazy_xp_function
4+
5+
6+
def jittable(x):
7+
"""A jittable function"""
8+
return x * 2.0
9+
10+
11+
def non_jittable(x):
12+
"""This function materializes the input array, so it will fail
13+
when wrapped in jax.jit
14+
"""
15+
xp = array_namespace(x)
16+
if xp.any(x < 0.0):
17+
raise ValueError("Negative values not allowed")
18+
return x
19+
20+
21+
def non_jittable2(x):
22+
return non_jittable(x)
23+
24+
25+
def static_params(x, n, flag=False):
26+
"""Function with static parameters that must not be jitted"""
27+
if flag and n > 0: # This fails if n or flag are jitted arrays
28+
return x * 2.0
29+
else:
30+
return x * 3.0
31+
32+
33+
def static_params1(x, n, flag=False):
34+
return static_params(x, n, flag)
35+
36+
37+
def static_params2(x, n, flag=False):
38+
return static_params(x, n, flag)
39+
40+
41+
def static_params3(x, n, flag=False):
42+
return static_params(x, n, flag)
43+
44+
45+
lazy_xp_function(jittable)
46+
lazy_xp_function(non_jittable2)
47+
lazy_xp_function(static_params1, static_argnums=(1, 2))
48+
lazy_xp_function(static_params2, static_argnames=("n", "flag"))
49+
lazy_xp_function(static_params3, static_argnums=1, static_argnames="flag")
50+
51+
52+
def test_lazy_xp_function(xp):
53+
x = xp.asarray([1.0, 2.0])
54+
55+
xp_assert_equal(jittable(x), xp.asarray([2.0, 4.0]))
56+
57+
xp_assert_equal(non_jittable(x), xp.asarray([1.0, 2.0])) # Not jitted
58+
if is_jax(xp):
59+
with pytest.raises(
60+
TypeError, match="Attempted boolean conversion of traced array"
61+
):
62+
non_jittable2(x) # Jitted
63+
else:
64+
xp_assert_equal(non_jittable2(x), xp.asarray([1.0, 2.0]))
65+
66+
67+
@pytest.mark.parametrize("func", [static_params1, static_params2, static_params3])
68+
def test_lazy_xp_function_static_params(xp, func):
69+
x = xp.asarray([1.0, 2.0])
70+
xp_assert_equal(func(x, 1), xp.asarray([3.0, 6.0]))
71+
xp_assert_equal(func(x, 1, True), xp.asarray([2.0, 4.0]))
72+
xp_assert_equal(func(x, 1, False), xp.asarray([3.0, 6.0]))
73+
xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0]))
74+
xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0]))
75+
xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0]))

0 commit comments

Comments
 (0)