Skip to content

Commit e003b61

Browse files
committed
ENH: array types: add dask.array support
1 parent f40d2cd commit e003b61

File tree

10 files changed

+64
-26
lines changed

10 files changed

+64
-26
lines changed

.github/workflows/array_api.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ jobs:
4343
name: Get commit message
4444
uses: ./.github/workflows/commit_message.yml
4545

46-
pytorch_cpu:
47-
name: Linux PyTorch/JAX/xp-strict CPU
46+
xp_cpu:
47+
name: Linux PyTorch/JAX/Dask/xp-strict CPU
4848
needs: get_commit_message
4949
if: >
5050
needs.get_commit_message.outputs.message == 1
@@ -86,6 +86,10 @@ jobs:
8686
run: |
8787
python -m pip install "jax[cpu]"
8888
89+
- name: Install Dask
90+
run: |
91+
python -m pip install git+https://github.com/dask/dask.git
92+
8993
- name: Prepare compiler cache
9094
id: prep-ccache
9195
shell: bash

dev.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,8 @@ class Test(Task):
710710
multiple=True,
711711
help=(
712712
"Array API backend "
713-
"('all', 'numpy', 'torch', 'cupy', 'array_api_strict', 'jax.numpy')."
713+
"('all', 'numpy', 'torch', 'cupy', 'array_api_strict',"
714+
" 'jax.numpy', 'dask.array')."
714715
)
715716
)
716717
# Argument can't have `help=`; used to consume all of `-- arg1 arg2 arg3`

scipy/_lib/_array_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
is_cupy_namespace as is_cupy,
2525
is_torch_namespace as is_torch,
2626
is_jax_namespace as is_jax,
27+
is_dask_namespace as is_dask,
2728
is_array_api_strict_namespace as is_array_api_strict
2829
)
2930

@@ -246,6 +247,9 @@ def _strict_check(actual, desired, xp, *,
246247
assert actual.dtype == desired.dtype, _msg
247248

248249
if check_shape:
250+
if is_dask(xp):
251+
actual.compute_chunk_sizes()
252+
desired.compute_chunk_sizes()
249253
_msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}"
250254
assert actual.shape == desired.shape, _msg
251255

scipy/_lib/tests/test_array_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from scipy.conftest import array_api_compatible
55
from scipy._lib._array_api import (
66
_GLOBAL_CONFIG, array_namespace, _asarray, xp_copy, xp_assert_equal, is_numpy,
7-
np_compat,
7+
np_compat, is_dask
88
)
99
from scipy._lib import array_api_extra as xpx
1010
from scipy._lib._array_api_no_0d import xp_assert_equal as xp_assert_equal_no_0d
@@ -77,6 +77,10 @@ def test_copy(self, xp):
7777
x[1] = 11
7878
x[2] = 12
7979

80+
if is_dask(xp):
81+
x.compute()
82+
y.compute()
83+
8084
assert x[0] != y[0]
8185
assert x[1] != y[1]
8286
assert x[2] != y[2]

scipy/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from scipy._lib._fpumode import get_fpu_mode
1414
from scipy._lib._testutils import FPUModeChangeWarning
15-
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE
15+
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE, xp_device
1616
from scipy._lib import _pep440
1717

1818
try:
@@ -178,6 +178,12 @@ def num_parallel_threads():
178178
except ImportError:
179179
pass
180180

181+
try:
182+
import dask.array # type: ignore[import-not-found]
183+
xp_available_backends.update({'dask.array': dask.array})
184+
except ImportError:
185+
pass
186+
181187
# by default, use all available backends
182188
if SCIPY_ARRAY_API.lower() not in ("1", "true"):
183189
SCIPY_ARRAY_API_ = json.loads(SCIPY_ARRAY_API)
@@ -366,6 +372,9 @@ def skip_or_xfail_xp_backends(xp, backends, kwargs, skip_or_xfail='skip'):
366372
for d in xp.empty(0).devices():
367373
if 'cpu' not in d.device_kind:
368374
skip_or_xfail(reason=reason)
375+
elif xp.__name__ == 'dask.array':
376+
if xp_device(xp.empty(0)) != 'cpu':
377+
skip_or_xfail(reason=reason)
369378

370379

371380
# Following the approach of NumPy's conftest.py...

scipy/ndimage/tests/test_filters.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def test_correlate01(self, xp):
195195

196196
@xfail_xp_backends('cupy', reason="Differs by a factor of two?")
197197
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
198+
@skip_xp_backends("dask.array", reason="output array is read-only.")
198199
def test_correlate01_overlap(self, xp):
199200
array = xp.reshape(xp.arange(256), (16, 16))
200201
weights = xp.asarray([2])
@@ -537,6 +538,7 @@ def test_correlate22(self, dtype_array, dtype_output, xp):
537538
assert_array_almost_equal(output, expected)
538539

539540
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
541+
@skip_xp_backends("dask.array", reason="output array is read-only.")
540542
@pytest.mark.parametrize('dtype_array', types)
541543
@pytest.mark.parametrize('dtype_output', types)
542544
def test_correlate23(self, dtype_array, dtype_output, xp):
@@ -556,6 +558,7 @@ def test_correlate23(self, dtype_array, dtype_output, xp):
556558
assert_array_almost_equal(output, expected)
557559

558560
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
561+
@skip_xp_backends("dask.array", reason="output array is read-only.")
559562
@pytest.mark.parametrize('dtype_array', types)
560563
@pytest.mark.parametrize('dtype_output', types)
561564
def test_correlate24(self, dtype_array, dtype_output, xp):
@@ -576,6 +579,7 @@ def test_correlate24(self, dtype_array, dtype_output, xp):
576579
assert_array_almost_equal(output, tcov)
577580

578581
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
582+
@skip_xp_backends("dask.array", reason="output array is read-only.")
579583
@pytest.mark.parametrize('dtype_array', types)
580584
@pytest.mark.parametrize('dtype_output', types)
581585
def test_correlate25(self, dtype_array, dtype_output, xp):
@@ -881,6 +885,7 @@ def test_gauss06(self, xp):
881885
assert_array_almost_equal(output1, output2)
882886

883887
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
888+
@skip_xp_backends("dask.array", reason="output array is read-only.")
884889
def test_gauss_memory_overlap(self, xp):
885890
input = xp.arange(100 * 100, dtype=xp.float32)
886891
input = xp.reshape(input, (100, 100))
@@ -1227,6 +1232,7 @@ def test_prewitt01(self, dtype, xp):
12271232
assert_array_almost_equal(t, output)
12281233

12291234
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
1235+
@skip_xp_backends("dask.array", reason="output array is read-only.")
12301236
@pytest.mark.parametrize('dtype', types + complex_types)
12311237
def test_prewitt02(self, dtype, xp):
12321238
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1289,6 +1295,7 @@ def test_sobel01(self, dtype, xp):
12891295
assert_array_almost_equal(t, output)
12901296

12911297
@skip_xp_backends("jax.numpy", reason="output array is read-only.",)
1298+
@skip_xp_backends("dask.array", reason="output array is read-only.")
12921299
@pytest.mark.parametrize('dtype', types + complex_types)
12931300
def test_sobel02(self, dtype, xp):
12941301
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1349,6 +1356,7 @@ def test_laplace01(self, dtype, xp):
13491356
assert_array_almost_equal(tmp1 + tmp2, output)
13501357

13511358
@skip_xp_backends("jax.numpy", reason="output array is read-only",)
1359+
@skip_xp_backends("dask.array", reason="output array is read-only.")
13521360
@pytest.mark.parametrize('dtype',
13531361
["int32", "float32", "float64",
13541362
"complex64", "complex128"])
@@ -1379,6 +1387,7 @@ def test_gaussian_laplace01(self, dtype, xp):
13791387
assert_array_almost_equal(tmp1 + tmp2, output)
13801388

13811389
@skip_xp_backends("jax.numpy", reason="output array is read-only")
1390+
@skip_xp_backends("dask.array", reason="output array is read-only.")
13821391
@pytest.mark.parametrize('dtype',
13831392
["int32", "float32", "float64",
13841393
"complex64", "complex128"])
@@ -1395,6 +1404,7 @@ def test_gaussian_laplace02(self, dtype, xp):
13951404
assert_array_almost_equal(tmp1 + tmp2, output)
13961405

13971406
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
1407+
@skip_xp_backends("dask.array", reason="output array is read-only.")
13981408
@pytest.mark.parametrize('dtype', types + complex_types)
13991409
def test_generic_laplace01(self, dtype, xp):
14001410
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1420,6 +1430,7 @@ def derivative2(input, axis, output, mode, cval, a, b):
14201430
assert_array_almost_equal(tmp, output)
14211431

14221432
@skip_xp_backends("jax.numpy", reason="output array is read-only")
1433+
@skip_xp_backends("dask.array", reason="output array is read-only.")
14231434
@pytest.mark.parametrize('dtype',
14241435
["int32", "float32", "float64",
14251436
"complex64", "complex128"])
@@ -1441,6 +1452,7 @@ def test_gaussian_gradient_magnitude01(self, dtype, xp):
14411452
xp_assert_close(output, expected, rtol=1e-6, atol=1e-6)
14421453

14431454
@skip_xp_backends("jax.numpy", reason="output array is read-only")
1455+
@skip_xp_backends("dask.array", reason="output array is read-only.")
14441456
@pytest.mark.parametrize('dtype',
14451457
["int32", "float32", "float64",
14461458
"complex64", "complex128"])
@@ -2640,6 +2652,7 @@ def test_gaussian_radius_invalid(xp):
26402652

26412653

26422654
@skip_xp_backends("jax.numpy", reason="output array is read-only")
2655+
@skip_xp_backends("dask.array", reason="output array is read-only.")
26432656
class TestThreading:
26442657
def check_func_thread(self, n, fun, args, out):
26452658
from threading import Thread

scipy/ndimage/tests/test_morphology.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,6 +2303,7 @@ def test_grey_erosion01(self, xp):
23032303
[5, 5, 3, 3, 1]]))
23042304

23052305
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2306+
@skip_xp_backends("dask.array", reason="output array is read-only.")
23062307
@xfail_xp_backends("cupy", reason="https://github.com/cupy/cupy/issues/8398")
23072308
def test_grey_erosion01_overlap(self, xp):
23082309

@@ -2498,6 +2499,7 @@ def test_morphological_laplace02(self, xp):
24982499
assert_array_almost_equal(output, expected)
24992500

25002501
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2502+
@skip_xp_backends("dask.array", reason="output array is read-only.")
25012503
def test_white_tophat01(self, xp):
25022504
array = xp.asarray([[3, 2, 5, 1, 4],
25032505
[7, 6, 9, 3, 5],
@@ -2551,6 +2553,7 @@ def test_white_tophat03(self, xp):
25512553
xp_assert_equal(output, expected)
25522554

25532555
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2556+
@skip_xp_backends("dask.array", reason="output array is read-only.")
25542557
def test_white_tophat04(self, xp):
25552558
array = np.eye(5, dtype=bool)
25562559
structure = np.ones((3, 3), dtype=bool)
@@ -2563,6 +2566,7 @@ def test_white_tophat04(self, xp):
25632566
ndimage.white_tophat(array, structure=structure, output=output)
25642567

25652568
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2569+
@skip_xp_backends("dask.array", reason="output array is read-only.")
25662570
def test_black_tophat01(self, xp):
25672571
array = xp.asarray([[3, 2, 5, 1, 4],
25682572
[7, 6, 9, 3, 5],
@@ -2616,6 +2620,7 @@ def test_black_tophat03(self, xp):
26162620
xp_assert_equal(output, expected)
26172621

26182622
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2623+
@skip_xp_backends("dask.array", reason="output array is read-only.")
26192624
def test_black_tophat04(self, xp):
26202625
array = xp.asarray(np.eye(5, dtype=bool))
26212626
structure = xp.asarray(np.ones((3, 3), dtype=bool))

scipy/special/tests/test_support_alternative_backends.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from scipy.conftest import array_api_compatible
66
from scipy import special
77
from scipy._lib._array_api_no_0d import xp_assert_close
8-
from scipy._lib._array_api import is_jax, is_torch, SCIPY_DEVICE
8+
from scipy._lib._array_api import is_jax, is_torch, SCIPY_DEVICE, is_dask
99
from scipy._lib.array_api_compat import numpy as np
1010

1111
try:
@@ -64,6 +64,9 @@ def test_support_alternative_backends(xp, f_name_n_args, dtype, shapes):
6464
):
6565
pytest.skip(f"`{f_name}` does not have an array-agnostic implementation "
6666
f"and cannot delegate to PyTorch.")
67+
68+
if is_dask(xp) and f_name == 'rel_entr':
69+
pytest.skip("boolean index assignment")
6770

6871
shapes = shapes[:n_args]
6972
f = getattr(special, f_name)

scipy/stats/tests/test_entropy.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal,
1212
xp_assert_less)
1313

14+
@pytest.mark.skip_xp_backends("dask.array", reason="boolean index assignment")
15+
@pytest.mark.usefixtures("skip_xp_backends")
16+
@array_api_compatible
1417
class TestEntropy:
15-
@array_api_compatible
18+
1619
def test_entropy_positive(self, xp):
1720
# See ticket #497
1821
pk = xp.asarray([0.5, 0.2, 0.3])
@@ -22,7 +25,6 @@ def test_entropy_positive(self, xp):
2225
xp_assert_equal(eself, xp.asarray(0.))
2326
xp_assert_less(-edouble, xp.asarray(0.))
2427

25-
@array_api_compatible
2628
def test_entropy_base(self, xp):
2729
pk = xp.ones(16)
2830
S = stats.entropy(pk, base=2.)
@@ -34,21 +36,18 @@ def test_entropy_base(self, xp):
3436
S2 = stats.entropy(pk, qk, base=2.)
3537
xp_assert_less(xp.abs(S/S2 - math.log(2.)), xp.asarray(1.e-5))
3638

37-
@array_api_compatible
3839
def test_entropy_zero(self, xp):
3940
# Test for PR-479
4041
x = xp.asarray([0., 1., 2.])
4142
xp_assert_close(stats.entropy(x),
4243
xp.asarray(0.63651416829481278))
4344

44-
@array_api_compatible
4545
def test_entropy_2d(self, xp):
4646
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
4747
qk = xp.asarray([[0.2, 0.1], [0.3, 0.6], [0.5, 0.3]])
4848
xp_assert_close(stats.entropy(pk, qk),
4949
xp.asarray([0.1933259, 0.18609809]))
5050

51-
@array_api_compatible
5251
def test_entropy_2d_zero(self, xp):
5352
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
5453
qk = xp.asarray([[0.0, 0.1], [0.3, 0.6], [0.5, 0.3]])
@@ -59,54 +58,46 @@ def test_entropy_2d_zero(self, xp):
5958
xp_assert_close(stats.entropy(pk, qk),
6059
xp.asarray([0.17403988, 0.18609809]))
6160

62-
@array_api_compatible
6361
def test_entropy_base_2d_nondefault_axis(self, xp):
6462
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
6563
xp_assert_close(stats.entropy(pk, axis=1),
6664
xp.asarray([0.63651417, 0.63651417, 0.66156324]))
6765

68-
@array_api_compatible
6966
def test_entropy_2d_nondefault_axis(self, xp):
7067
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
7168
qk = xp.asarray([[0.2, 0.1], [0.3, 0.6], [0.5, 0.3]])
7269
xp_assert_close(stats.entropy(pk, qk, axis=1),
7370
xp.asarray([0.23104906, 0.23104906, 0.12770641]))
7471

75-
@array_api_compatible
7672
def test_entropy_raises_value_error(self, xp):
7773
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
7874
qk = xp.asarray([[0.1, 0.2], [0.6, 0.3]])
7975
message = "Array shapes are incompatible for broadcasting."
8076
with pytest.raises(ValueError, match=message):
8177
stats.entropy(pk, qk)
8278

83-
@array_api_compatible
8479
def test_base_entropy_with_axis_0_is_equal_to_default(self, xp):
8580
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
8681
xp_assert_close(stats.entropy(pk, axis=0),
8782
stats.entropy(pk))
8883

89-
@array_api_compatible
9084
def test_entropy_with_axis_0_is_equal_to_default(self, xp):
9185
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
9286
qk = xp.asarray([[0.2, 0.1], [0.3, 0.6], [0.5, 0.3]])
9387
xp_assert_close(stats.entropy(pk, qk, axis=0),
9488
stats.entropy(pk, qk))
9589

96-
@array_api_compatible
9790
def test_base_entropy_transposed(self, xp):
9891
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
9992
xp_assert_close(stats.entropy(pk.T),
10093
stats.entropy(pk, axis=1))
10194

102-
@array_api_compatible
10395
def test_entropy_transposed(self, xp):
10496
pk = xp.asarray([[0.1, 0.2], [0.6, 0.3], [0.3, 0.5]])
10597
qk = xp.asarray([[0.2, 0.1], [0.3, 0.6], [0.5, 0.3]])
10698
xp_assert_close(stats.entropy(pk.T, qk.T),
10799
stats.entropy(pk, qk, axis=1))
108100

109-
@array_api_compatible
110101
def test_entropy_broadcasting(self, xp):
111102
rng = np.random.default_rng(74187315492831452)
112103
x = xp.asarray(rng.random(3))
@@ -115,22 +106,21 @@ def test_entropy_broadcasting(self, xp):
115106
xp_assert_equal(res[0], stats.entropy(x, y[0, ...]))
116107
xp_assert_equal(res[1], stats.entropy(x, y[1, ...]))
117108

118-
@array_api_compatible
119109
def test_entropy_shape_mismatch(self, xp):
120110
x = xp.ones((10, 1, 12))
121111
y = xp.ones((11, 2))
122112
message = "Array shapes are incompatible for broadcasting."
123113
with pytest.raises(ValueError, match=message):
124114
stats.entropy(x, y)
125115

126-
@array_api_compatible
127116
def test_input_validation(self, xp):
128117
x = xp.ones(10)
129118
message = "`base` must be a positive number."
130119
with pytest.raises(ValueError, match=message):
131120
stats.entropy(x, base=-2)
132121

133122

123+
@pytest.mark.skip_xp_backends("dask.array", reason="No sorting in Dask")
134124
@array_api_compatible
135125
@pytest.mark.usefixtures("skip_xp_backends")
136126
class TestDifferentialEntropy:

0 commit comments

Comments
 (0)