Skip to content

Commit e40dc94

Browse files
authored
more fixes from Thomas (#20)
* address comments and get rest of tests passing * Force cast to numpy in scipy.ndimage * fix binary_erosion * fix sosfilt * clean up * address comments * fix a couple ndimage tests * integrate latest array-api-compat * clean up
1 parent 32b1199 commit e40dc94

File tree

11 files changed

+44
-49
lines changed

11 files changed

+44
-49
lines changed

scipy/_lib/tests/test__util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import scipy._lib.array_api_extra as xpx
2222
from scipy import cluster, interpolate, linalg, optimize, sparse, spatial, stats
2323

24+
skip_xp_backends = pytest.mark.skip_xp_backends
2425

2526
@pytest.mark.slow
2627
def test__aligned_zeros():
@@ -586,6 +587,7 @@ class TestLazywhere:
586587

587588
@pytest.mark.fail_slow(10)
588589
@pytest.mark.filterwarnings('ignore::RuntimeWarning') # overflows, etc.
590+
@skip_xp_backends("dask.array", reason="lazywhere doesn't work with dask")
589591
@given(n_arrays=n_arrays, rng_seed=rng_seed, dtype=dtype, p=p, data=data)
590592
@pytest.mark.thread_unsafe
591593
def test_basic(self, n_arrays, rng_seed, dtype, p, data, xp):

scipy/_lib/tests/test_array_api.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from scipy._lib import array_api_extra as xpx
99
from scipy._lib._array_api_no_0d import xp_assert_equal as xp_assert_equal_no_0d
1010

11+
skip_xp_backends = pytest.mark.skip_xp_backends
12+
1113

1214
@pytest.mark.skipif(not _GLOBAL_CONFIG["SCIPY_ARRAY_API"],
1315
reason="Array API test; set environment variable SCIPY_ARRAY_API=1 to run it")
@@ -59,6 +61,10 @@ def test_array_api_extra_hook(self):
5961
with pytest.raises(TypeError, match=msg):
6062
xpx.atleast_nd("abc", ndim=0)
6163

64+
@skip_xp_backends(
65+
"dask.array",
66+
reason="raw dask.array namespace doesn't ignores copy=True in asarray"
67+
)
6268
def test_copy(self, xp):
6369
for _xp in [xp, None]:
6470
x = xp.asarray([1, 2, 3])

scipy/ndimage/_measurements.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,10 @@ def single_group(vals):
630630
if labels is None:
631631
return single_group(input)
632632

633+
# manually cast to numpy
634+
# since libaries (e.g. dask) that implement __array_function__
635+
# will not return a numpy array from broadcast_arrays
636+
labels = np.asarray(labels)
633637
# ensure input and labels match sizes
634638
input, labels = np.broadcast_arrays(input, labels)
635639

@@ -944,6 +948,10 @@ def single_group(vals, positions):
944948
if labels is None:
945949
return single_group(input, positions)
946950

951+
# manually cast to numpy
952+
# since libaries (e.g. dask) that implement __array_function__
953+
# will not return a numpy array from broadcast_arrays
954+
labels = np.asarray(labels)
947955
# ensure input and labels match sizes
948956
input, labels = np.broadcast_arrays(input, labels)
949957

scipy/ndimage/_morphology.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from . import _nd_image
3737
from . import _filters
3838

39+
from scipy._lib._array_api import is_dask, array_namespace
40+
3941
__all__ = ['iterate_structure', 'generate_binary_structure', 'binary_erosion',
4042
'binary_dilation', 'binary_opening', 'binary_closing',
4143
'binary_hit_or_miss', 'binary_propagation', 'binary_fill_holes',
@@ -220,7 +222,14 @@ def _binary_erosion(input, structure, iterations, mask, output,
220222
except TypeError as e:
221223
raise TypeError('iterations parameter should be an integer') from e
222224

223-
input = np.asarray(input)
225+
if is_dask(array_namespace(input)):
226+
# Note: If you create an dask array with ones
227+
# it does a stride trick where it makes an array
228+
# (with stride 0) using a scalar
229+
# this messes up the C ndimage iteration code
230+
input = np.asarray(input, order="C")
231+
else:
232+
input = np.asarray(input)
224233
ndim = input.ndim
225234
if np.iscomplexobj(input):
226235
raise TypeError('Complex type not supported')

scipy/ndimage/tests/test_filters.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def test_correlate01(self, xp):
191191

192192
@xfail_xp_backends('cupy', reason="Differs by a factor of two?")
193193
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
194-
@skip_xp_backends("dask.array", reason="output array is read-only.")
194+
@skip_xp_backends("dask.array", reason="wrong answer")
195195
def test_correlate01_overlap(self, xp):
196196
array = xp.reshape(xp.arange(256), (16, 16))
197197
weights = xp.asarray([2])
@@ -881,7 +881,7 @@ def test_gauss06(self, xp):
881881
assert_array_almost_equal(output1, output2)
882882

883883
@skip_xp_backends("jax.numpy", reason="output array is read-only.")
884-
@skip_xp_backends("dask.array", reason="output array is read-only.")
884+
@skip_xp_backends("dask.array", reason="wrong result")
885885
def test_gauss_memory_overlap(self, xp):
886886
input = xp.arange(100 * 100, dtype=xp.float32)
887887
input = xp.reshape(input, (100, 100))
@@ -2647,7 +2647,6 @@ def test_gaussian_radius_invalid(xp):
26472647

26482648

26492649
@skip_xp_backends("jax.numpy", reason="output array is read-only")
2650-
@skip_xp_backends("dask.array", reason="output array is read-only.")
26512650
class TestThreading:
26522651
def check_func_thread(self, n, fun, args, out):
26532652
from threading import Thread

scipy/ndimage/tests/test_measurements.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,14 @@ def test_basic(self, xp):
149149

150150

151151
def test_label01(xp):
152-
data = xp.ones([])
152+
data = xp.ones(())
153153
out, n = ndimage.label(data)
154154
assert out == 1
155155
assert n == 1
156156

157157

158158
def test_label02(xp):
159-
data = xp.zeros([])
159+
data = xp.zeros(())
160160
out, n = ndimage.label(data)
161161
assert out == 0
162162
assert n == 0
@@ -557,6 +557,8 @@ def test_value_indices03(xp):
557557

558558
trueKeys = xp.unique_values(a)
559559
vi = ndimage.value_indices(a)
560+
# TODO: list(trueKeys) needs len of trueKeys
561+
# (which is unknown for dask since it is the result of an unique call)
560562
assert list(vi.keys()) == list(trueKeys)
561563
for k in [int(x) for x in trueKeys]:
562564
trueNdx = xp.nonzero(a == k)
@@ -665,7 +667,6 @@ def test_sum11(xp):
665667
assert_almost_equal(output, xp.asarray(6.0), check_0d=False)
666668

667669

668-
@skip_xp_backends("dask.array", reason="data-dependent output shapes")
669670
def test_sum12(xp):
670671
labels = xp.asarray([[1, 2], [2, 4]], dtype=xp.int8)
671672
for type in types:
@@ -675,7 +676,6 @@ def test_sum12(xp):
675676
assert_array_almost_equal(output, xp.asarray([4.0, 0.0, 5.0]))
676677

677678

678-
@skip_xp_backends("dask.array", reason="data-dependent output shapes")
679679
def test_sum_labels(xp):
680680
labels = xp.asarray([[1, 2], [2, 4]], dtype=xp.int8)
681681
for type in types:
@@ -688,7 +688,6 @@ def test_sum_labels(xp):
688688
assert xp.all(output_sum == output_labels)
689689
assert_array_almost_equal(output_labels, xp.asarray([4.0, 0.0, 5.0]))
690690

691-
@xfail_xp_backends("dask.array", reason="dask outputs wrong results here")
692691
def test_mean01(xp):
693692
labels = np.asarray([1, 0], dtype=bool)
694693
labels = xp.asarray(labels)
@@ -699,7 +698,6 @@ def test_mean01(xp):
699698
assert_almost_equal(output, xp.asarray(2.0), check_0d=False)
700699

701700

702-
@xfail_xp_backends("dask.array", reason="dask outputs wrong results here")
703701
def test_mean02(xp):
704702
labels = np.asarray([1, 0], dtype=bool)
705703
input = np.asarray([[1, 2], [3, 4]], dtype=bool)
@@ -710,7 +708,6 @@ def test_mean02(xp):
710708
assert_almost_equal(output, xp.asarray(1.0), check_0d=False)
711709

712710

713-
@xfail_xp_backends("dask.array", reason="dask outputs wrong results here")
714711
def test_mean03(xp):
715712
labels = xp.asarray([1, 2])
716713
for type in types:
@@ -721,7 +718,6 @@ def test_mean03(xp):
721718
assert_almost_equal(output, xp.asarray(3.0), check_0d=False)
722719

723720

724-
@xfail_xp_backends("dask.array", reason="dask outputs wrong results here")
725721
def test_mean04(xp):
726722
labels = xp.asarray([[1, 2], [2, 4]], dtype=xp.int8)
727723
with np.errstate(all='ignore'):
@@ -768,7 +764,6 @@ def test_minimum03(xp):
768764
assert_almost_equal(output, xp.asarray(2.0), check_0d=False)
769765

770766

771-
@skip_xp_backends('dask.array', reason="no argsort in Dask")
772767
def test_minimum04(xp):
773768
labels = xp.asarray([[1, 2], [2, 3]])
774769
for type in types:
@@ -808,7 +803,6 @@ def test_maximum03(xp):
808803
assert_almost_equal(output, xp.asarray(4.0), check_0d=False)
809804

810805

811-
@skip_xp_backends('dask.array', reason="no argsort in Dask")
812806
def test_maximum04(xp):
813807
labels = xp.asarray([[1, 2], [2, 3]])
814808
for type in types:
@@ -848,8 +842,6 @@ def test_median02(xp):
848842
assert_almost_equal(output, xp.asarray(1.0), check_0d=False)
849843

850844

851-
@skip_xp_backends("dask.array",
852-
reason="dask.array.median only implemented for along an axis.")
853845
def test_median03(xp):
854846
a = xp.asarray([[1, 2, 0, 1],
855847
[5, 3, 0, 4],
@@ -863,15 +855,13 @@ def test_median03(xp):
863855
assert_almost_equal(output, xp.asarray(3.0), check_0d=False)
864856

865857

866-
@xfail_xp_backends("dask.array", reason="Crash inside dask searchsorted")
867858
def test_median_gh12836_bool(xp):
868859
# test boolean addition fix on example from gh-12836
869860
a = np.asarray([1, 1], dtype=bool)
870861
a = xp.asarray(a)
871862
output = ndimage.median(a, labels=xp.ones((2,)), index=xp.asarray([1]))
872863
assert_array_almost_equal(output, xp.asarray([1.0]))
873864

874-
@xfail_xp_backends("dask.array", reason="Crash inside dask searchsorted")
875865
def test_median_no_int_overflow(xp):
876866
# test integer overflow fix on example from gh-12836
877867
a = xp.asarray([65, 70], dtype=xp.int8)
@@ -912,10 +902,6 @@ def test_variance04(xp):
912902
output = ndimage.variance(input)
913903
assert_almost_equal(output, xp.asarray(0.25), check_0d=False)
914904

915-
# dask.array is maybe due to failed conversion to numpy?
916-
# array-api-strict should've caught use of non array API functions I think
917-
@skip_xp_backends("dask.array",
918-
reason="conjugate called on dask.array which doesn't exist")
919905
def test_variance05(xp):
920906
labels = xp.asarray([2, 2, 3])
921907
for type in types:
@@ -925,7 +911,6 @@ def test_variance05(xp):
925911
output = ndimage.variance(input, labels, 2)
926912
assert_almost_equal(output, xp.asarray(1.0), check_0d=False)
927913

928-
@skip_xp_backends("dask.array", reason="Data-dependent output shapes")
929914
def test_variance06(xp):
930915
labels = xp.asarray([2, 2, 3, 3, 4])
931916
with np.errstate(all='ignore'):
@@ -970,10 +955,6 @@ def test_standard_deviation04(xp):
970955
assert_almost_equal(output, xp.asarray(0.5), check_0d=False)
971956

972957

973-
# dask.array is maybe due to failed conversion to numpy?
974-
# array-api-strict should've caught use of non array API functions I think
975-
@skip_xp_backends("dask.array",
976-
reason="conjugate called on dask.array which doesn't exist")
977958
def test_standard_deviation05(xp):
978959
labels = xp.asarray([2, 2, 3])
979960
for type in types:
@@ -983,7 +964,6 @@ def test_standard_deviation05(xp):
983964
assert_almost_equal(output, xp.asarray(1.0), check_0d=False)
984965

985966

986-
@skip_xp_backends("dask.array", reason="data-dependent output shapes")
987967
def test_standard_deviation06(xp):
988968
labels = xp.asarray([2, 2, 3, 3, 4])
989969
with np.errstate(all='ignore'):
@@ -996,7 +976,6 @@ def test_standard_deviation06(xp):
996976
assert_array_almost_equal(output, xp.asarray([1.0, 1.0, 0.0]))
997977

998978

999-
@skip_xp_backends("dask.array", reason="data-dependent output shapes")
1000979
def test_standard_deviation07(xp):
1001980
labels = xp.asarray([1])
1002981
with np.errstate(all='ignore'):
@@ -1070,7 +1049,6 @@ def test_minimum_position06(xp):
10701049
assert output == (0, 1)
10711050

10721051

1073-
@skip_xp_backends('dask.array', reason="no argsort in Dask")
10741052
def test_minimum_position07(xp):
10751053
labels = xp.asarray([1, 2, 3, 4])
10761054
for type in types:
@@ -1136,7 +1114,6 @@ def test_maximum_position05(xp):
11361114
assert output == (0, 0)
11371115

11381116

1139-
@skip_xp_backends('dask.array', reason="no argsort in Dask")
11401117
def test_maximum_position06(xp):
11411118
labels = xp.asarray([1, 2, 0, 4])
11421119
for type in types:
@@ -1149,7 +1126,6 @@ def test_maximum_position06(xp):
11491126
assert output[0] == (0, 0)
11501127
assert output[1] == (1, 1)
11511128

1152-
@xfail_xp_backends("dask.array", reason="crash in dask.array searchsorted")
11531129
@xfail_xp_backends("torch", reason="output[1] is wrong on pytorch")
11541130
def test_maximum_position07(xp):
11551131
# Test float labels
@@ -1165,7 +1141,6 @@ def test_maximum_position07(xp):
11651141
assert output[1] == (0, 3)
11661142

11671143

1168-
@xfail_xp_backends("dask.array", reason="dask wrong answer")
11691144
def test_extrema01(xp):
11701145
labels = np.asarray([1, 0], dtype=bool)
11711146
labels = xp.asarray(labels)
@@ -1182,7 +1157,6 @@ def test_extrema01(xp):
11821157
assert output1 == (output2, output3, output4, output5)
11831158

11841159

1185-
@xfail_xp_backends("dask.array", reason="dask wrong answer")
11861160
def test_extrema02(xp):
11871161
labels = xp.asarray([1, 2])
11881162
for type in types:
@@ -1201,7 +1175,6 @@ def test_extrema02(xp):
12011175
assert output1 == (output2, output3, output4, output5)
12021176

12031177

1204-
@skip_xp_backends('dask.array', reason="no argsort in Dask")
12051178
def test_extrema03(xp):
12061179
labels = xp.asarray([[1, 2], [2, 3]])
12071180
for type in types:
@@ -1230,7 +1203,6 @@ def test_extrema03(xp):
12301203
assert output1[3] == output5
12311204

12321205

1233-
@skip_xp_backends('dask.array', reason="no argsort in Dask")
12341206
def test_extrema04(xp):
12351207
labels = xp.asarray([1, 2, 0, 4])
12361208
for type in types:
@@ -1307,7 +1279,6 @@ def test_center_of_mass06(xp):
13071279
assert output == expected
13081280

13091281

1310-
@xfail_xp_backends("dask.array", reason="wrong output shape")
13111282
def test_center_of_mass07(xp):
13121283
labels = xp.asarray([1, 0])
13131284
expected = (0.5, 0.0)
@@ -1317,7 +1288,6 @@ def test_center_of_mass07(xp):
13171288
assert output == expected
13181289

13191290

1320-
@xfail_xp_backends("dask.array", reason="wrong output shape")
13211291
def test_center_of_mass08(xp):
13221292
labels = xp.asarray([1, 2])
13231293
expected = (0.5, 1.0)
@@ -1327,7 +1297,6 @@ def test_center_of_mass08(xp):
13271297
assert output == expected
13281298

13291299

1330-
@skip_xp_backends("dask.array", reason="data-dependent output shapes")
13311300
def test_center_of_mass09(xp):
13321301
labels = xp.asarray((1, 2))
13331302
expected = xp.asarray([(0.5, 0.0), (0.5, 1.0)], dtype=xp.float64)
@@ -1365,7 +1334,6 @@ def test_histogram03(xp):
13651334
assert_array_almost_equal(output[1], expected2)
13661335

13671336

1368-
@skip_xp_backends("dask.array", reason="data-dependent output shapes")
13691337
def test_stat_funcs_2d(xp):
13701338
a = xp.asarray([[5, 6, 0, 0, 0], [8, 9, 0, 0, 0], [0, 0, 0, 3, 5]])
13711339
lbl = xp.asarray([[1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [0, 0, 0, 2, 2]])

scipy/ndimage/tests/test_morphology.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@
1515
pytestmark = [skip_xp_backends(cpu_only=True, exceptions=['cupy', 'jax.numpy'])]
1616

1717

18-
@xfail_xp_backends('dask.array',
19-
reason="Dask.array gets wrong results here. "
20-
"Some tests can pass when creating input array from list of ones"
21-
"instead of xp.ones, so maybe something is getting corrupted here."
22-
)
2318
class TestNdimageMorphology:
2419

2520
@xfail_xp_backends('cupy', reason='CuPy does not have distance_transform_bf.')

scipy/signal/_filter_design.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,11 @@ def group_delay(system, w=512, whole=False, fs=2*pi):
727727

728728
def _validate_sos(sos):
729729
"""Helper to validate a SOS input"""
730+
# manually cast to numpy array
731+
# since libs like dask implement __array_function__
732+
# (and will return a dask array instead of casting to
733+
# ndarray in atleast_2d)
734+
sos = np.asarray(sos)
730735
sos = np.atleast_2d(sos)
731736
if sos.ndim != 2:
732737
raise ValueError('sos array must be 2D')

0 commit comments

Comments
 (0)