Skip to content

Commit 4b14378

Browse files
Address more comments (#22)
* Address more comments * Address more comments Co-authored-by: Guido Imperiale <crusaderky@gmail.com> * update comment * Update scipy/ndimage/tests/test_morphology.py Co-authored-by: Guido Imperiale <crusaderky@gmail.com> --------- Co-authored-by: Guido Imperiale <crusaderky@gmail.com>
1 parent d33febb commit 4b14378

16 files changed

+144
-114
lines changed

scipy/_lib/tests/test_array_api.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,6 @@ def test_array_api_extra_hook(self):
6666
with pytest.raises(TypeError, match=msg):
6767
xpx.atleast_nd("abc", ndim=0)
6868

69-
@skip_xp_backends(
70-
"dask.array",
71-
reason="raw dask.array namespace doesn't ignores copy=True in asarray"
72-
)
7369
def test_copy(self, xp):
7470
for _xp in [xp, None]:
7571
x = xp.asarray([1, 2, 3])

scipy/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,9 +399,10 @@ def skip_or_xfail_xp_backends(request: pytest.FixtureRequest,
399399
if 'cpu' not in d.device_kind:
400400
skip_or_xfail(reason=reason)
401401
elif xp.__name__ == 'dask.array' and 'dask.array' not in exceptions:
402-
if xp_device(xp.empty(0)) != 'cpu':
403-
skip_or_xfail(reason=reason)
404-
402+
# dask has no device. 'cpu' is a hack introduced by array-api-compat.
403+
# Force to revisit this when in the future
404+
# dask adds proper device support
405+
assert xp_device(xp.empty(0)) == 'cpu'
405406

406407
# Following the approach of NumPy's conftest.py...
407408
# Use a known and persistent tmpdir for hypothesis' caches, which

scipy/fft/tests/test_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ def test_definition(self, xp):
518518
x2 = xp.asarray([0, 1, 2, 3, 4, -5, -4, -3, -2, -1], dtype=xp.float64)
519519

520520
# default dtype varies across backends
521+
521522
y = 9 * fft.fftfreq(9, xp=xp)
522523
xp_assert_close(y, x, check_dtype=False, check_namespace=True)
523524

@@ -549,6 +550,7 @@ def test_definition(self, xp):
549550
x2 = xp.asarray([0, 1, 2, 3, 4, 5], dtype=xp.float64)
550551

551552
# default dtype varies across backends
553+
552554
y = 9 * fft.rfftfreq(9, xp=xp)
553555
xp_assert_close(y, x, check_dtype=False, check_namespace=True)
554556

scipy/fft/tests/test_real_transforms.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from scipy.fft import dct, idct, dctn, idctn, dst, idst, dstn, idstn
77
import scipy.fft as fft
88
from scipy import fftpack
9-
from scipy._lib._array_api import xp_assert_close
9+
from scipy._lib._array_api import xp_copy, xp_assert_close
1010

1111
skip_xp_backends = pytest.mark.skip_xp_backends
1212

@@ -195,10 +195,9 @@ def test_orthogonalize_noop(func, type, norm, xp):
195195
@skip_xp_backends(cpu_only=True)
196196
@pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
197197
def test_orthogonalize_dct1(norm, xp):
198-
x_np = np.random.rand(100)
199-
x = xp.asarray(x_np)
198+
x = xp.asarray(np.random.rand(100))
200199

201-
x2 = xp.asarray(x_np.copy())
200+
x2 = xp_copy(x, xp=xp)
202201
x2[0] *= SQRT_2
203202
x2[-1] *= SQRT_2
204203

@@ -230,9 +229,8 @@ def test_orthogonalize_dcst2(func, norm, xp):
230229
@pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
231230
@pytest.mark.parametrize("func", [dct, dst])
232231
def test_orthogonalize_dcst3(func, norm, xp):
233-
x_np = np.random.rand(100)
234-
x = xp.asarray(x_np)
235-
x2 = xp.asarray(x_np.copy())
232+
x = xp.asarray(np.random.rand(100))
233+
x2 = xp_copy(x, xp=xp)
236234
x2[0 if func == dct else -1] *= SQRT_2
237235

238236
y1 = func(x, type=3, norm=norm, orthogonalize=True)

scipy/ndimage/_morphology.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from . import _nd_image
3737
from . import _filters
3838

39-
from scipy._lib._array_api import is_dask, array_namespace
39+
from scipy._lib.array_api_compat import is_dask_array
4040

4141
__all__ = ['iterate_structure', 'generate_binary_structure', 'binary_erosion',
4242
'binary_dilation', 'binary_opening', 'binary_closing',
@@ -222,7 +222,7 @@ def _binary_erosion(input, structure, iterations, mask, output,
222222
except TypeError as e:
223223
raise TypeError('iterations parameter should be an integer') from e
224224

225-
if is_dask(array_namespace(input)):
225+
if is_dask_array(input):
226226
# Note: If you create an dask array with ones
227227
# it does a stride trick where it makes an array
228228
# (with stride 0) using a scalar
@@ -1800,6 +1800,7 @@ def morphological_laplace(input, size=None, footprint=None, structure=None,
18001800
Output
18011801
18021802
"""
1803+
input = np.asarray(input)
18031804
tmp1 = grey_dilation(input, size, footprint, structure, None, mode,
18041805
cval, origin, axes=axes)
18051806
if isinstance(output, np.ndarray):
@@ -1812,7 +1813,6 @@ def morphological_laplace(input, size=None, footprint=None, structure=None,
18121813
tmp2 = grey_erosion(input, size, footprint, structure, None, mode,
18131814
cval, origin, axes=axes)
18141815
np.add(tmp1, tmp2, tmp2)
1815-
input = np.asarray(input)
18161816
np.subtract(tmp2, input, tmp2)
18171817
np.subtract(tmp2, input, tmp2)
18181818
return tmp2

scipy/ndimage/tests/test_filters.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def test_correlate01(self, xp):
191191

192192
@xfail_xp_backends('cupy', reason="Differs by a factor of two?")
193193
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
194-
@skip_xp_backends("dask.array", reason="wrong answer")
194+
@xfail_xp_backends("dask.array", reason="wrong answer")
195195
def test_correlate01_overlap(self, xp):
196196
array = xp.reshape(xp.arange(256), (16, 16))
197197
weights = xp.asarray([2])
@@ -534,7 +534,7 @@ def test_correlate22(self, dtype_array, dtype_output, xp):
534534
assert_array_almost_equal(output, expected)
535535

536536
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
537-
@skip_xp_backends("dask.array", reason="output array is read-only.")
537+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
538538
@pytest.mark.parametrize('dtype_array', types)
539539
@pytest.mark.parametrize('dtype_output', types)
540540
def test_correlate23(self, dtype_array, dtype_output, xp):
@@ -554,7 +554,7 @@ def test_correlate23(self, dtype_array, dtype_output, xp):
554554
assert_array_almost_equal(output, expected)
555555

556556
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
557-
@skip_xp_backends("dask.array", reason="output array is read-only.")
557+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
558558
@pytest.mark.parametrize('dtype_array', types)
559559
@pytest.mark.parametrize('dtype_output', types)
560560
def test_correlate24(self, dtype_array, dtype_output, xp):
@@ -575,7 +575,7 @@ def test_correlate24(self, dtype_array, dtype_output, xp):
575575
assert_array_almost_equal(output, tcov)
576576

577577
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
578-
@skip_xp_backends("dask.array", reason="output array is read-only.")
578+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
579579
@pytest.mark.parametrize('dtype_array', types)
580580
@pytest.mark.parametrize('dtype_output', types)
581581
def test_correlate25(self, dtype_array, dtype_output, xp):
@@ -881,7 +881,7 @@ def test_gauss06(self, xp):
881881
assert_array_almost_equal(output1, output2)
882882

883883
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
884-
@skip_xp_backends("dask.array", reason="wrong result")
884+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
885885
def test_gauss_memory_overlap(self, xp):
886886
input = xp.arange(100 * 100, dtype=xp.float32)
887887
input = xp.reshape(input, (100, 100))
@@ -1228,7 +1228,7 @@ def test_prewitt01(self, dtype, xp):
12281228
assert_array_almost_equal(t, output)
12291229

12301230
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
1231-
@skip_xp_backends("dask.array", reason="output array is read-only.")
1231+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
12321232
@pytest.mark.parametrize('dtype', types + complex_types)
12331233
def test_prewitt02(self, dtype, xp):
12341234
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1291,7 +1291,7 @@ def test_sobel01(self, dtype, xp):
12911291
assert_array_almost_equal(t, output)
12921292

12931293
@skip_xp_backends("jax.numpy", reason="output array is read-only.",)
1294-
@skip_xp_backends("dask.array", reason="output array is read-only.")
1294+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
12951295
@pytest.mark.parametrize('dtype', types + complex_types)
12961296
def test_sobel02(self, dtype, xp):
12971297
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1352,7 +1352,7 @@ def test_laplace01(self, dtype, xp):
13521352
assert_array_almost_equal(tmp1 + tmp2, output)
13531353

13541354
@skip_xp_backends("jax.numpy", reason="output array is read-only",)
1355-
@skip_xp_backends("dask.array", reason="output array is read-only.")
1355+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
13561356
@pytest.mark.parametrize('dtype',
13571357
["int32", "float32", "float64",
13581358
"complex64", "complex128"])
@@ -1383,7 +1383,7 @@ def test_gaussian_laplace01(self, dtype, xp):
13831383
assert_array_almost_equal(tmp1 + tmp2, output)
13841384

13851385
@skip_xp_backends("jax.numpy", reason="output array is read-only")
1386-
@skip_xp_backends("dask.array", reason="output array is read-only.")
1386+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
13871387
@pytest.mark.parametrize('dtype',
13881388
["int32", "float32", "float64",
13891389
"complex64", "complex128"])
@@ -1400,7 +1400,7 @@ def test_gaussian_laplace02(self, dtype, xp):
14001400
assert_array_almost_equal(tmp1 + tmp2, output)
14011401

14021402
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
1403-
@skip_xp_backends("dask.array", reason="output array is read-only.")
1403+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
14041404
@pytest.mark.parametrize('dtype', types + complex_types)
14051405
def test_generic_laplace01(self, dtype, xp):
14061406
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1426,7 +1426,7 @@ def derivative2(input, axis, output, mode, cval, a, b):
14261426
assert_array_almost_equal(tmp, output)
14271427

14281428
@skip_xp_backends("jax.numpy", reason="output array is read-only")
1429-
@skip_xp_backends("dask.array", reason="output array is read-only.")
1429+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
14301430
@pytest.mark.parametrize('dtype',
14311431
["int32", "float32", "float64",
14321432
"complex64", "complex128"])
@@ -1447,7 +1447,7 @@ def test_gaussian_gradient_magnitude01(self, dtype, xp):
14471447
xp_assert_close(output, expected, rtol=1e-6, atol=1e-6)
14481448

14491449
@skip_xp_backends("jax.numpy", reason="output array is read-only")
1450-
@skip_xp_backends("dask.array", reason="output array is read-only.")
1450+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
14511451
@pytest.mark.parametrize('dtype',
14521452
["int32", "float32", "float64",
14531453
"complex64", "complex128"])
@@ -1838,9 +1838,7 @@ def test_rank06(self, xp):
18381838
@skip_xp_backends("jax.numpy",
18391839
reason="assignment destination is read-only",
18401840
)
1841-
@xfail_xp_backends("dask.array",
1842-
reason="wrong answer",
1843-
)
1841+
@xfail_xp_backends("dask.array", reason="wrong answer")
18441842
def test_rank06_overlap(self, xp):
18451843
array = xp.asarray([[3, 2, 5, 1, 4],
18461844
[5, 8, 3, 7, 1],
@@ -2647,7 +2645,7 @@ def test_gaussian_radius_invalid(xp):
26472645

26482646

26492647
@skip_xp_backends("jax.numpy", reason="output array is read-only")
2650-
@skip_xp_backends("dask.array", reason="wrong answer")
2648+
@xfail_xp_backends("dask.array", reason="wrong answer")
26512649
class TestThreading:
26522650
def check_func_thread(self, n, fun, args, out):
26532651
from threading import Thread

scipy/ndimage/tests/test_measurements.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,17 +548,16 @@ def test_value_indices02(xp):
548548
ndimage.value_indices(data)
549549

550550

551-
@skip_xp_backends("dask.array", reason="len on data-dependent output shapes")
552551
def test_value_indices03(xp):
553552
"Test different input array shapes, from 1-D to 4-D"
554553
for shape in [(36,), (18, 2), (3, 3, 4), (3, 3, 2, 2)]:
555554
a = xp.asarray((12*[1]+12*[2]+12*[3]), dtype=xp.int32)
556555
a = xp.reshape(a, shape)
557556

558-
trueKeys = xp.unique_values(a)
557+
# convert to numpy to prevent issues with data-dependent shapes
558+
# from unique for dask
559+
trueKeys = np.asarray(xp.unique_values(a))
559560
vi = ndimage.value_indices(a)
560-
# TODO: list(trueKeys) needs len of trueKeys
561-
# (which is unknown for dask since it is the result of an unique call)
562561
assert list(vi.keys()) == list(trueKeys)
563562
for k in [int(x) for x in trueKeys]:
564563
trueNdx = xp.nonzero(a == k)
@@ -688,6 +687,7 @@ def test_sum_labels(xp):
688687
assert xp.all(output_sum == output_labels)
689688
assert_array_almost_equal(output_labels, xp.asarray([4.0, 0.0, 5.0]))
690689

690+
691691
def test_mean01(xp):
692692
labels = np.asarray([1, 0], dtype=bool)
693693
labels = xp.asarray(labels)
@@ -819,7 +819,6 @@ def test_maximum05(xp):
819819
assert ndimage.maximum(x) == -1
820820

821821

822-
@pytest.mark.filterwarnings("ignore::FutureWarning:dask")
823822
def test_median01(xp):
824823
a = xp.asarray([[1, 2, 0, 1],
825824
[5, 3, 0, 4],
@@ -862,6 +861,7 @@ def test_median_gh12836_bool(xp):
862861
output = ndimage.median(a, labels=xp.ones((2,)), index=xp.asarray([1]))
863862
assert_array_almost_equal(output, xp.asarray([1.0]))
864863

864+
865865
def test_median_no_int_overflow(xp):
866866
# test integer overflow fix on example from gh-12836
867867
a = xp.asarray([65, 70], dtype=xp.int8)
@@ -902,6 +902,7 @@ def test_variance04(xp):
902902
output = ndimage.variance(input)
903903
assert_almost_equal(output, xp.asarray(0.25), check_0d=False)
904904

905+
905906
def test_variance05(xp):
906907
labels = xp.asarray([2, 2, 3])
907908
for type in types:
@@ -911,6 +912,7 @@ def test_variance05(xp):
911912
output = ndimage.variance(input, labels, 2)
912913
assert_almost_equal(output, xp.asarray(1.0), check_0d=False)
913914

915+
914916
def test_variance06(xp):
915917
labels = xp.asarray([2, 2, 3, 3, 4])
916918
with np.errstate(all='ignore'):
@@ -1126,6 +1128,7 @@ def test_maximum_position06(xp):
11261128
assert output[0] == (0, 0)
11271129
assert output[1] == (1, 1)
11281130

1131+
11291132
@xfail_xp_backends("torch", reason="output[1] is wrong on pytorch")
11301133
def test_maximum_position07(xp):
11311134
# Test float labels

scipy/ndimage/tests/test_morphology.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,7 +2315,6 @@ def test_grey_erosion01(self, xp):
23152315
[5, 5, 3, 3, 1]]))
23162316

23172317
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2318-
@skip_xp_backends("dask.array", reason="output array is read-only.")
23192318
@xfail_xp_backends("cupy", reason="https://github.com/cupy/cupy/issues/8398")
23202319
def test_grey_erosion01_overlap(self, xp):
23212320

@@ -2511,7 +2510,7 @@ def test_morphological_laplace02(self, xp):
25112510
assert_array_almost_equal(output, expected)
25122511

25132512
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2514-
@skip_xp_backends("dask.array", reason="output array is read-only.")
2513+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
25152514
def test_white_tophat01(self, xp):
25162515
array = xp.asarray([[3, 2, 5, 1, 4],
25172516
[7, 6, 9, 3, 5],
@@ -2565,7 +2564,7 @@ def test_white_tophat03(self, xp):
25652564
xp_assert_equal(output, expected)
25662565

25672566
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2568-
@skip_xp_backends("dask.array", reason="output array is read-only.")
2567+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
25692568
def test_white_tophat04(self, xp):
25702569
array = np.eye(5, dtype=bool)
25712570
structure = np.ones((3, 3), dtype=bool)
@@ -2578,7 +2577,7 @@ def test_white_tophat04(self, xp):
25782577
ndimage.white_tophat(array, structure=structure, output=output)
25792578

25802579
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2581-
@skip_xp_backends("dask.array", reason="output array is read-only.")
2580+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
25822581
def test_black_tophat01(self, xp):
25832582
array = xp.asarray([[3, 2, 5, 1, 4],
25842583
[7, 6, 9, 3, 5],
@@ -2632,7 +2631,7 @@ def test_black_tophat03(self, xp):
26322631
xp_assert_equal(output, expected)
26332632

26342633
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2635-
@skip_xp_backends("dask.array", reason="output array is read-only.")
2634+
@skip_xp_backends("dask.array", reason="converts dask output array to numpy")
26362635
def test_black_tophat04(self, xp):
26372636
array = xp.asarray(np.eye(5, dtype=bool))
26382637
structure = xp.asarray(np.ones((3, 3), dtype=bool))

scipy/signal/_filter_design.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,9 @@ def normalize(b, a):
17791779
"""
17801780
num, den = b, a
17811781

1782+
# cast to numpy by hand to avoid libraries like dask
1783+
# trying to dispatch this function via NEP 18
1784+
den = np.asarray(den)
17821785
den = np.atleast_1d(den)
17831786
num = np.atleast_2d(_align_nums(num))
17841787

scipy/signal/_signaltools.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4087,6 +4087,9 @@ def detrend(data: np.ndarray, axis: int = -1,
40874087
else:
40884088
dshape = data.shape
40894089
N = dshape[axis]
4090+
# Manually cast to numpy to prevent
4091+
# NEP18 dispatching for libraries like dask
4092+
bp = np.asarray(bp)
40904093
bp = np.sort(np.unique(np.concatenate(np.atleast_1d(0, bp, N))))
40914094
if np.any(bp > N):
40924095
raise ValueError("Breakpoints must be less than length "

0 commit comments

Comments
 (0)