Skip to content

Commit 121f9c9

Browse files
committed
Address more comments
1 parent d33febb commit 121f9c9

File tree

13 files changed

+128
-93
lines changed

13 files changed

+128
-93
lines changed

scipy/fft/tests/test_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ def test_definition(self, xp):
518518
x2 = xp.asarray([0, 1, 2, 3, 4, -5, -4, -3, -2, -1], dtype=xp.float64)
519519

520520
# default dtype varies across backends
521+
521522
y = 9 * fft.fftfreq(9, xp=xp)
522523
xp_assert_close(y, x, check_dtype=False, check_namespace=True)
523524

@@ -549,6 +550,7 @@ def test_definition(self, xp):
549550
x2 = xp.asarray([0, 1, 2, 3, 4, 5], dtype=xp.float64)
550551

551552
# default dtype varies across backends
553+
552554
y = 9 * fft.rfftfreq(9, xp=xp)
553555
xp_assert_close(y, x, check_dtype=False, check_namespace=True)
554556

scipy/fft/tests/test_real_transforms.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from scipy.fft import dct, idct, dctn, idctn, dst, idst, dstn, idstn
77
import scipy.fft as fft
88
from scipy import fftpack
9-
from scipy._lib._array_api import xp_assert_close
9+
from scipy._lib._array_api import xp_copy, xp_assert_close
1010

1111
skip_xp_backends = pytest.mark.skip_xp_backends
1212

@@ -195,10 +195,9 @@ def test_orthogonalize_noop(func, type, norm, xp):
195195
@skip_xp_backends(cpu_only=True)
196196
@pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
197197
def test_orthogonalize_dct1(norm, xp):
198-
x_np = np.random.rand(100)
199-
x = xp.asarray(x_np)
198+
x = xp.asarray(np.random.rand(100))
200199

201-
x2 = xp.asarray(x_np.copy())
200+
x2 = xp_copy(x, xp=xp)
202201
x2[0] *= SQRT_2
203202
x2[-1] *= SQRT_2
204203

@@ -230,9 +229,8 @@ def test_orthogonalize_dcst2(func, norm, xp):
230229
@pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
231230
@pytest.mark.parametrize("func", [dct, dst])
232231
def test_orthogonalize_dcst3(func, norm, xp):
233-
x_np = np.random.rand(100)
234-
x = xp.asarray(x_np)
235-
x2 = xp.asarray(x_np.copy())
232+
x = xp.asarray(np.random.rand(100))
233+
x2 = xp_copy(x, xp=xp)
236234
x2[0 if func == dct else -1] *= SQRT_2
237235

238236
y1 = func(x, type=3, norm=norm, orthogonalize=True)

scipy/ndimage/_morphology.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from . import _nd_image
3737
from . import _filters
3838

39-
from scipy._lib._array_api import is_dask, array_namespace
39+
from scipy._lib.array_api_compat import is_dask_array
4040

4141
__all__ = ['iterate_structure', 'generate_binary_structure', 'binary_erosion',
4242
'binary_dilation', 'binary_opening', 'binary_closing',
@@ -222,7 +222,7 @@ def _binary_erosion(input, structure, iterations, mask, output,
222222
except TypeError as e:
223223
raise TypeError('iterations parameter should be an integer') from e
224224

225-
if is_dask(array_namespace(input)):
225+
if is_dask_array(input):
226226
# Note: If you create an dask array with ones
227227
# it does a stride trick where it makes an array
228228
# (with stride 0) using a scalar
@@ -1802,6 +1802,7 @@ def morphological_laplace(input, size=None, footprint=None, structure=None,
18021802
"""
18031803
tmp1 = grey_dilation(input, size, footprint, structure, None, mode,
18041804
cval, origin, axes=axes)
1805+
input = np.asarray(input)
18051806
if isinstance(output, np.ndarray):
18061807
grey_erosion(input, size, footprint, structure, output, mode,
18071808
cval, origin, axes=axes)
@@ -1812,7 +1813,6 @@ def morphological_laplace(input, size=None, footprint=None, structure=None,
18121813
tmp2 = grey_erosion(input, size, footprint, structure, None, mode,
18131814
cval, origin, axes=axes)
18141815
np.add(tmp1, tmp2, tmp2)
1815-
input = np.asarray(input)
18161816
np.subtract(tmp2, input, tmp2)
18171817
np.subtract(tmp2, input, tmp2)
18181818
return tmp2

scipy/ndimage/tests/test_filters.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def test_correlate01(self, xp):
191191

192192
@xfail_xp_backends('cupy', reason="Differs by a factor of two?")
193193
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
194-
@skip_xp_backends("dask.array", reason="wrong answer")
194+
@xfail_xp_backends("dask.array", reason="wrong answer")
195195
def test_correlate01_overlap(self, xp):
196196
array = xp.reshape(xp.arange(256), (16, 16))
197197
weights = xp.asarray([2])
@@ -881,7 +881,7 @@ def test_gauss06(self, xp):
881881
assert_array_almost_equal(output1, output2)
882882

883883
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
884-
@skip_xp_backends("dask.array", reason="wrong result")
884+
@xfail_xp_backends("dask.array", reason="output keyword not handled properly")
885885
def test_gauss_memory_overlap(self, xp):
886886
input = xp.arange(100 * 100, dtype=xp.float32)
887887
input = xp.reshape(input, (100, 100))
@@ -1228,7 +1228,7 @@ def test_prewitt01(self, dtype, xp):
12281228
assert_array_almost_equal(t, output)
12291229

12301230
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
1231-
@skip_xp_backends("dask.array", reason="output array is read-only.")
1231+
@xfail_xp_backends("dask.array", reason="output array is read-only.")
12321232
@pytest.mark.parametrize('dtype', types + complex_types)
12331233
def test_prewitt02(self, dtype, xp):
12341234
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1838,9 +1838,7 @@ def test_rank06(self, xp):
18381838
@skip_xp_backends("jax.numpy",
18391839
reason="assignment destination is read-only",
18401840
)
1841-
@xfail_xp_backends("dask.array",
1842-
reason="wrong answer",
1843-
)
1841+
@xfail_xp_backends("dask.array", reason="wrong answer")
18441842
def test_rank06_overlap(self, xp):
18451843
array = xp.asarray([[3, 2, 5, 1, 4],
18461844
[5, 8, 3, 7, 1],
@@ -2647,7 +2645,7 @@ def test_gaussian_radius_invalid(xp):
26472645

26482646

26492647
@skip_xp_backends("jax.numpy", reason="output array is read-only")
2650-
@skip_xp_backends("dask.array", reason="wrong answer")
2648+
@xfail_xp_backends("dask.array", reason="wrong answer")
26512649
class TestThreading:
26522650
def check_func_thread(self, n, fun, args, out):
26532651
from threading import Thread

scipy/ndimage/tests/test_measurements.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,17 +548,16 @@ def test_value_indices02(xp):
548548
ndimage.value_indices(data)
549549

550550

551-
@skip_xp_backends("dask.array", reason="len on data-dependent output shapes")
552551
def test_value_indices03(xp):
553552
"Test different input array shapes, from 1-D to 4-D"
554553
for shape in [(36,), (18, 2), (3, 3, 4), (3, 3, 2, 2)]:
555554
a = xp.asarray((12*[1]+12*[2]+12*[3]), dtype=xp.int32)
556555
a = xp.reshape(a, shape)
557556

558-
trueKeys = xp.unique_values(a)
557+
# convert to numpy to prevent issues with data-dependent shapes
558+
# from unique for dask
559+
trueKeys = np.asarray(xp.unique_values(a))
559560
vi = ndimage.value_indices(a)
560-
# TODO: list(trueKeys) needs len of trueKeys
561-
# (which is unknown for dask since it is the result of an unique call)
562561
assert list(vi.keys()) == list(trueKeys)
563562
for k in [int(x) for x in trueKeys]:
564563
trueNdx = xp.nonzero(a == k)
@@ -819,7 +818,6 @@ def test_maximum05(xp):
819818
assert ndimage.maximum(x) == -1
820819

821820

822-
@pytest.mark.filterwarnings("ignore::FutureWarning:dask")
823821
def test_median01(xp):
824822
a = xp.asarray([[1, 2, 0, 1],
825823
[5, 3, 0, 4],

scipy/ndimage/tests/test_morphology.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2315,7 +2315,6 @@ def test_grey_erosion01(self, xp):
23152315
[5, 5, 3, 3, 1]]))
23162316

23172317
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2318-
@skip_xp_backends("dask.array", reason="output array is read-only.")
23192318
@xfail_xp_backends("cupy", reason="https://github.com/cupy/cupy/issues/8398")
23202319
def test_grey_erosion01_overlap(self, xp):
23212320

@@ -2521,6 +2520,8 @@ def test_white_tophat01(self, xp):
25212520
tmp = ndimage.grey_opening(array, footprint=footprint,
25222521
structure=structure)
25232522
expected = array - tmp
2523+
# array created by xp.zeros is non-writeable for dask
2524+
# and jax
25242525
output = xp.zeros(array.shape, dtype=array.dtype)
25252526
ndimage.white_tophat(array, footprint=footprint,
25262527
structure=structure, output=output)
@@ -2574,6 +2575,8 @@ def test_white_tophat04(self, xp):
25742575
structure = xp.asarray(structure)
25752576

25762577
# Check that type mismatch is properly handled
2578+
# This output array is read-only for dask and jax
2579+
# TODO: investigate why for dask?
25772580
output = xp.empty_like(array, dtype=xp.float64)
25782581
ndimage.white_tophat(array, structure=structure, output=output)
25792582

@@ -2588,6 +2591,7 @@ def test_black_tophat01(self, xp):
25882591
tmp = ndimage.grey_closing(array, footprint=footprint,
25892592
structure=structure)
25902593
expected = tmp - array
2594+
# This output array is read-only for dask and jax
25912595
output = xp.zeros(array.shape, dtype=array.dtype)
25922596
ndimage.black_tophat(array, footprint=footprint,
25932597
structure=structure, output=output)

scipy/signal/_filter_design.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1791,7 +1791,10 @@ def normalize(b, a):
17911791
raise ValueError("Denominator must have at least on nonzero element.")
17921792

17931793
# Trim leading zeros in denominator, leave at least one.
1794-
den = np.trim_zeros(den, 'f')
1794+
1795+
# cast to numpy by hand to avoid libraries like dask
1796+
# trying to dispatch this function via NEP 18
1797+
den = np.trim_zeros(np.asarray(den), 'f')
17951798

17961799
# Normalize transfer function
17971800
num, den = num / den[0], den / den[0]

scipy/signal/_signaltools.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4087,7 +4087,10 @@ def detrend(data: np.ndarray, axis: int = -1,
40874087
else:
40884088
dshape = data.shape
40894089
N = dshape[axis]
4090-
bp = np.sort(np.unique(np.concatenate(np.atleast_1d(0, bp, N))))
4090+
# Manually cast to numpy to prevent
4091+
# NEP18 dispatching for libraries like dask
4092+
bp = np.asarray(np.concatenate(np.atleast_1d(0, bp, N)))
4093+
bp = np.sort(np.unique(bp))
40914094
if np.any(bp > N):
40924095
raise ValueError("Breakpoints must be less than length "
40934096
"of data along given axis.")

scipy/signal/tests/test_signaltools.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,6 @@ def test_dtype_deprecation(self, xp):
291291
convolve(a, b)
292292

293293

294-
@pytest.mark.filterwarnings("ignore::FutureWarning:dask")
295294
@skip_xp_backends(cpu_only=True, exceptions=['cupy'])
296295
class TestConvolve2d:
297296

@@ -477,8 +476,12 @@ def test_consistency_convolve_funcs(self, xp):
477476
a = xp.arange(5)
478477
b = xp.asarray([3.2, 1.4, 3])
479478
for mode in ['full', 'valid', 'same']:
480-
xp_assert_close(xp.asarray(np.convolve(a, b, mode=mode)),
481-
signal.convolve(a, b, mode=mode))
479+
# cast to numpy when calling np.convolve
480+
# to prevent NEP18 dispatching e.g. from dask
481+
xp_assert_close(
482+
xp.asarray(np.convolve(np.asarray(a), np.asarray(b), mode=mode)),
483+
signal.convolve(a, b, mode=mode)
484+
)
482485
xp_assert_close(
483486
xp.squeeze(
484487
signal.convolve2d(xp.asarray([a]), xp.asarray([b]), mode=mode),
@@ -508,7 +511,6 @@ def test_large_array(self, xp):
508511
assert fails[0].size == 0
509512

510513

511-
@pytest.mark.filterwarnings("ignore::FutureWarning:dask")
512514
@skip_xp_backends(cpu_only=True, exceptions=['cupy'])
513515
class TestFFTConvolve:
514516

@@ -845,7 +847,9 @@ def test_random_data(self, axes, xp):
845847
np.random.seed(1234)
846848
a = xp.asarray(np.random.rand(1233) + 1j * np.random.rand(1233))
847849
b = xp.asarray(np.random.rand(1321) + 1j * np.random.rand(1321))
848-
expected = xp.asarray(np.convolve(a, b, 'full'))
850+
# cast to numpy before np.convolve
851+
# to prevent NEP 18 dispatching for e.g. dask
852+
expected = xp.asarray(np.convolve(np.asarray(a), np.asarray(b), 'full'))
849853

850854
if axes == '':
851855
out = fftconvolve(a, b, 'full')
@@ -860,7 +864,9 @@ def test_random_data_axes(self, axes, xp):
860864
np.random.seed(1234)
861865
a = xp.asarray(np.random.rand(1233) + 1j * np.random.rand(1233))
862866
b = xp.asarray(np.random.rand(1321) + 1j * np.random.rand(1321))
863-
expected = xp.asarray(np.convolve(a, b, 'full'))
867+
# cast to numpy before np.convolve
868+
# to prevent NEP 18 dispatching for e.g. dask
869+
expected = xp.asarray(np.convolve(np.asarray(a), np.asarray(b), 'full'))
864870

865871
a = xp.asarray(np.tile(a, [2, 1]))
866872
b = xp.asarray(np.tile(b, [2, 1]))
@@ -913,7 +919,9 @@ def test_random_data_multidim_axes(self, axes, xp):
913919
def test_many_sizes(self, n, xp):
914920
a = xp.asarray(np.random.rand(n) + 1j * np.random.rand(n))
915921
b = xp.asarray(np.random.rand(n) + 1j * np.random.rand(n))
916-
expected = xp.asarray(np.convolve(a, b, 'full'))
922+
# cast to numpy before np.convolve
923+
# to prevent NEP 18 dispatching for e.g. dask
924+
expected = xp.asarray(np.convolve(np.asarray(a), np.asarray(b), 'full'))
917925

918926
out = fftconvolve(a, b, 'full')
919927
xp_assert_close(out, expected, atol=1e-10)
@@ -965,10 +973,7 @@ def gen_oa_shapes_eq(sizes):
965973

966974

967975
@skip_xp_backends("jax.numpy", reason="fails all around")
968-
@skip_xp_backends("dask.array",
969-
reason="Gets converted to numpy at some point for some reason. "
970-
"Probably also suffers from boolean indexing issues"
971-
)
976+
@skip_xp_backends("dask.array", reason="wrong answer")
972977
class TestOAConvolve:
973978
@pytest.mark.slow()
974979
@pytest.mark.parametrize('shape_a_0, shape_b_0',
@@ -2452,7 +2457,6 @@ def decimal(self, dt, xp):
24522457
dt = np.cdouble
24532458

24542459
# emulate np.finfo(dt).precision for complex64 and complex128
2455-
# note: unwrapped dask has no finfo
24562460
prec = {64: 15, 32: 6}[xp.finfo(dt).bits]
24572461
return int(2 * prec / 3)
24582462

@@ -4018,7 +4022,6 @@ def test_nonnumeric_dtypes(func, xp):
40184022
class TestSOSFilt:
40194023

40204024
# The test_rank* tests are pulled from _TestLinearFilter
4021-
@pytest.mark.filterwarnings("ignore::FutureWarning") # for dask
40224025
@skip_xp_backends('jax.numpy', reason='buffer array is read-only')
40234026
def test_rank1(self, dt, xp):
40244027
dt = getattr(xp, dt)
@@ -4050,7 +4053,6 @@ def test_rank1(self, dt, xp):
40504053
y = sosfilt(sos, x)
40514054
xp_assert_close(y, xp.asarray([1.0, 2, 2, 2, 2, 2, 2, 2]))
40524055

4053-
@pytest.mark.filterwarnings("ignore::FutureWarning") # for dask
40544056
@skip_xp_backends('jax.numpy', reason='buffer array is read-only')
40554057
def test_rank2(self, dt, xp):
40564058
dt = getattr(xp, dt)
@@ -4078,7 +4080,6 @@ def test_rank2(self, dt, xp):
40784080
y = sosfilt(sos, x, axis=1)
40794081
assert_array_almost_equal(y_r2_a1, y)
40804082

4081-
@pytest.mark.filterwarnings("ignore::FutureWarning") # for dask
40824083
@skip_xp_backends('jax.numpy', reason='buffer array is read-only')
40834084
def test_rank3(self, dt, xp):
40844085
dt = getattr(xp, dt)
@@ -4334,7 +4335,6 @@ def test_bp(self, xp):
43344335
with assert_raises(ValueError):
43354336
detrend(data, type="linear", bp=3)
43364337

4337-
@pytest.mark.filterwarnings("ignore::FutureWarning:dask")
43384338
@pytest.mark.parametrize('bp', [np.array([0, 2]), [0, 2]])
43394339
def test_detrend_array_bp(self, bp, xp):
43404340
# regression test for https://github.com/scipy/scipy/issues/18675

scipy/signal/tests/test_windows.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from scipy.fft import fft
1010
from scipy.signal import windows, get_window, resample
1111
from scipy._lib._array_api import (
12-
xp_assert_close, xp_assert_equal, array_namespace, is_dask,
13-
is_torch, is_jax, is_cupy, assert_array_almost_equal, SCIPY_DEVICE,
12+
xp_assert_close, xp_assert_equal, array_namespace, is_torch, is_jax, is_cupy,
13+
assert_array_almost_equal, SCIPY_DEVICE,
1414
)
1515

1616
skip_xp_backends = pytest.mark.skip_xp_backends
@@ -852,16 +852,12 @@ def test_lanczos(self, xp):
852852
get_window('sinc', 6, xp=xp))
853853

854854

855+
@skip_xp_backends("dask.array", reason="https://github.com/dask/dask/issues/2620")
855856
def test_windowfunc_basics(xp):
856857
for window_name, params in window_funcs:
857858
window = getattr(windows, window_name)
858859
if is_jax(xp) and window_name in ['taylor', 'chebwin']:
859860
pytest.skip(reason=f'{window_name = }: item assignment')
860-
if is_dask(xp):
861-
# https://github.com/dask/dask/issues/2620
862-
pytest.skip(
863-
reason="dask doesn't support FFT along axis containing multiple chunks"
864-
)
865861
if window_name in ['dpss']:
866862
if is_cupy(xp):
867863
pytest.skip(reason='dpss window is not implemented for cupy')

0 commit comments

Comments
 (0)