Skip to content

Commit 3df386d

Browse files
committed
Address more comments
1 parent d33febb commit 3df386d

File tree

7 files changed

+24
-30
lines changed

7 files changed

+24
-30
lines changed

scipy/fft/tests/test_real_transforms.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from scipy.fft import dct, idct, dctn, idctn, dst, idst, dstn, idstn
77
import scipy.fft as fft
88
from scipy import fftpack
9-
from scipy._lib._array_api import xp_assert_close
9+
from scipy._lib._array_api import xp_copy, xp_assert_close
1010

1111
skip_xp_backends = pytest.mark.skip_xp_backends
1212

@@ -195,10 +195,9 @@ def test_orthogonalize_noop(func, type, norm, xp):
195195
@skip_xp_backends(cpu_only=True)
196196
@pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
197197
def test_orthogonalize_dct1(norm, xp):
198-
x_np = np.random.rand(100)
199-
x = xp.asarray(x_np)
198+
x = xp.asarray(np.random.rand(100))
200199

201-
x2 = xp.asarray(x_np.copy())
200+
x2 = xp_copy(x, xp=xp)
202201
x2[0] *= SQRT_2
203202
x2[-1] *= SQRT_2
204203

@@ -230,9 +229,8 @@ def test_orthogonalize_dcst2(func, norm, xp):
230229
@pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
231230
@pytest.mark.parametrize("func", [dct, dst])
232231
def test_orthogonalize_dcst3(func, norm, xp):
233-
x_np = np.random.rand(100)
234-
x = xp.asarray(x_np)
235-
x2 = xp.asarray(x_np.copy())
232+
x = xp.asarray(np.random.rand(100))
233+
x2 = xp_copy(x, xp=xp)
236234
x2[0 if func == dct else -1] *= SQRT_2
237235

238236
y1 = func(x, type=3, norm=norm, orthogonalize=True)

scipy/ndimage/_morphology.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from . import _nd_image
3737
from . import _filters
3838

39-
from scipy._lib._array_api import is_dask, array_namespace
39+
from scipy._lib.array_api_compat.array_api_compat import is_dask_array
4040

4141
__all__ = ['iterate_structure', 'generate_binary_structure', 'binary_erosion',
4242
'binary_dilation', 'binary_opening', 'binary_closing',
@@ -222,7 +222,7 @@ def _binary_erosion(input, structure, iterations, mask, output,
222222
except TypeError as e:
223223
raise TypeError('iterations parameter should be an integer') from e
224224

225-
if is_dask(array_namespace(input)):
225+
if is_dask_array(input):
226226
# Note: If you create an dask array with ones
227227
# it does a stride trick where it makes an array
228228
# (with stride 0) using a scalar
@@ -1802,6 +1802,7 @@ def morphological_laplace(input, size=None, footprint=None, structure=None,
18021802
"""
18031803
tmp1 = grey_dilation(input, size, footprint, structure, None, mode,
18041804
cval, origin, axes=axes)
1805+
input = np.asarray(input)
18051806
if isinstance(output, np.ndarray):
18061807
grey_erosion(input, size, footprint, structure, output, mode,
18071808
cval, origin, axes=axes)
@@ -1812,7 +1813,6 @@ def morphological_laplace(input, size=None, footprint=None, structure=None,
18121813
tmp2 = grey_erosion(input, size, footprint, structure, None, mode,
18131814
cval, origin, axes=axes)
18141815
np.add(tmp1, tmp2, tmp2)
1815-
input = np.asarray(input)
18161816
np.subtract(tmp2, input, tmp2)
18171817
np.subtract(tmp2, input, tmp2)
18181818
return tmp2

scipy/ndimage/tests/test_filters.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def test_correlate01(self, xp):
191191

192192
@xfail_xp_backends('cupy', reason="Differs by a factor of two?")
193193
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
194-
@skip_xp_backends("dask.array", reason="wrong answer")
194+
@xfail_xp_backends("dask.array", reason="wrong answer")
195195
def test_correlate01_overlap(self, xp):
196196
array = xp.reshape(xp.arange(256), (16, 16))
197197
weights = xp.asarray([2])
@@ -881,7 +881,7 @@ def test_gauss06(self, xp):
881881
assert_array_almost_equal(output1, output2)
882882

883883
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
884-
@skip_xp_backends("dask.array", reason="wrong result")
884+
@xfail_xp_backends("dask.array", reason="output keyword not handled properly")
885885
def test_gauss_memory_overlap(self, xp):
886886
input = xp.arange(100 * 100, dtype=xp.float32)
887887
input = xp.reshape(input, (100, 100))
@@ -1228,7 +1228,7 @@ def test_prewitt01(self, dtype, xp):
12281228
assert_array_almost_equal(t, output)
12291229

12301230
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
1231-
@skip_xp_backends("dask.array", reason="output array is read-only.")
1231+
@xfail_xp_backends("dask.array", reason="output array is read-only.")
12321232
@pytest.mark.parametrize('dtype', types + complex_types)
12331233
def test_prewitt02(self, dtype, xp):
12341234
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
@@ -1838,9 +1838,7 @@ def test_rank06(self, xp):
18381838
@skip_xp_backends("jax.numpy",
18391839
reason="assignment destination is read-only",
18401840
)
1841-
@xfail_xp_backends("dask.array",
1842-
reason="wrong answer",
1843-
)
1841+
@xfail_xp_backends("dask.array", reason="wrong answer")
18441842
def test_rank06_overlap(self, xp):
18451843
array = xp.asarray([[3, 2, 5, 1, 4],
18461844
[5, 8, 3, 7, 1],
@@ -2647,7 +2645,7 @@ def test_gaussian_radius_invalid(xp):
26472645

26482646

26492647
@skip_xp_backends("jax.numpy", reason="output array is read-only")
2650-
@skip_xp_backends("dask.array", reason="wrong answer")
2648+
@xfail_xp_backends("dask.array", reason="wrong answer")
26512649
class TestThreading:
26522650
def check_func_thread(self, n, fun, args, out):
26532651
from threading import Thread

scipy/ndimage/tests/test_measurements.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -548,17 +548,16 @@ def test_value_indices02(xp):
548548
ndimage.value_indices(data)
549549

550550

551-
@skip_xp_backends("dask.array", reason="len on data-dependent output shapes")
552551
def test_value_indices03(xp):
553552
"Test different input array shapes, from 1-D to 4-D"
554553
for shape in [(36,), (18, 2), (3, 3, 4), (3, 3, 2, 2)]:
555554
a = xp.asarray((12*[1]+12*[2]+12*[3]), dtype=xp.int32)
556555
a = xp.reshape(a, shape)
557556

558-
trueKeys = xp.unique_values(a)
557+
# convert to numpy to prevent issues with data-dependent shapes
558+
# from unique for dask
559+
trueKeys = np.asarray(xp.unique_values(a))
559560
vi = ndimage.value_indices(a)
560-
# TODO: list(trueKeys) needs len of trueKeys
561-
# (which is unknown for dask since it is the result of an unique call)
562561
assert list(vi.keys()) == list(trueKeys)
563562
for k in [int(x) for x in trueKeys]:
564563
trueNdx = xp.nonzero(a == k)

scipy/ndimage/tests/test_morphology.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2315,7 +2315,6 @@ def test_grey_erosion01(self, xp):
23152315
[5, 5, 3, 3, 1]]))
23162316

23172317
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
2318-
@skip_xp_backends("dask.array", reason="output array is read-only.")
23192318
@xfail_xp_backends("cupy", reason="https://github.com/cupy/cupy/issues/8398")
23202319
def test_grey_erosion01_overlap(self, xp):
23212320

@@ -2521,6 +2520,8 @@ def test_white_tophat01(self, xp):
25212520
tmp = ndimage.grey_opening(array, footprint=footprint,
25222521
structure=structure)
25232522
expected = array - tmp
2523+
# array created by xp.zeros is non-writeable for dask
2524+
# and jax
25242525
output = xp.zeros(array.shape, dtype=array.dtype)
25252526
ndimage.white_tophat(array, footprint=footprint,
25262527
structure=structure, output=output)
@@ -2574,6 +2575,8 @@ def test_white_tophat04(self, xp):
25742575
structure = xp.asarray(structure)
25752576

25762577
# Check that type mismatch is properly handled
2578+
# This output array is read-only for dask and jax
2579+
# TODO: investigate why for dask?
25772580
output = xp.empty_like(array, dtype=xp.float64)
25782581
ndimage.white_tophat(array, structure=structure, output=output)
25792582

@@ -2588,6 +2591,7 @@ def test_black_tophat01(self, xp):
25882591
tmp = ndimage.grey_closing(array, footprint=footprint,
25892592
structure=structure)
25902593
expected = tmp - array
2594+
# This output array is read-only for dask and jax
25912595
output = xp.zeros(array.shape, dtype=array.dtype)
25922596
ndimage.black_tophat(array, footprint=footprint,
25932597
structure=structure, output=output)

scipy/signal/tests/test_signaltools.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2452,7 +2452,6 @@ def decimal(self, dt, xp):
24522452
dt = np.cdouble
24532453

24542454
# emulate np.finfo(dt).precision for complex64 and complex128
2455-
# note: unwrapped dask has no finfo
24562455
prec = {64: 15, 32: 6}[xp.finfo(dt).bits]
24572456
return int(2 * prec / 3)
24582457

scipy/signal/tests/test_windows.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from scipy.fft import fft
1010
from scipy.signal import windows, get_window, resample
1111
from scipy._lib._array_api import (
12-
xp_assert_close, xp_assert_equal, array_namespace, is_dask,
13-
is_torch, is_jax, is_cupy, assert_array_almost_equal, SCIPY_DEVICE,
12+
xp_assert_close, xp_assert_equal, array_namespace, is_torch, is_jax, is_cupy,
13+
assert_array_almost_equal, SCIPY_DEVICE,
1414
)
1515

1616
skip_xp_backends = pytest.mark.skip_xp_backends
@@ -852,16 +852,12 @@ def test_lanczos(self, xp):
852852
get_window('sinc', 6, xp=xp))
853853

854854

855+
@skip_xp_backends("dask.array", reason="https://github.com/dask/dask/issues/2620")
855856
def test_windowfunc_basics(xp):
856857
for window_name, params in window_funcs:
857858
window = getattr(windows, window_name)
858859
if is_jax(xp) and window_name in ['taylor', 'chebwin']:
859860
pytest.skip(reason=f'{window_name = }: item assignment')
860-
if is_dask(xp):
861-
# https://github.com/dask/dask/issues/2620
862-
pytest.skip(
863-
reason="dask doesn't support FFT along axis containing multiple chunks"
864-
)
865861
if window_name in ['dpss']:
866862
if is_cupy(xp):
867863
pytest.skip(reason='dpss window is not implemented for cupy')

0 commit comments

Comments
 (0)