Skip to content

Commit 7b21d32

Browse files
authored
TST: replace pytest.xfail with skip/xfail_xp_backends (scipy#22311)
1 parent 90ad156 commit 7b21d32

File tree

2 files changed

+51
-28
lines changed

2 files changed

+51
-28
lines changed

scipy/ndimage/tests/test_measurements.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from numpy.testing import suppress_warnings
66

77
from scipy._lib._array_api import (
8-
is_jax,
98
is_torch,
109
array_namespace,
1110
xp_assert_equal,
@@ -22,6 +21,7 @@
2221
from . import types
2322

2423
skip_xp_backends = pytest.mark.skip_xp_backends
24+
xfail_xp_backends = pytest.mark.xfail_xp_backends
2525
pytestmark = [skip_xp_backends(cpu_only=True, exceptions=['cupy', 'jax.numpy'])]
2626

2727
IS_WINDOWS_AND_NP1 = os.name == 'nt' and np.__version__ < '2'
@@ -364,10 +364,8 @@ def test_label_output_dtype(xp):
364364
assert output.dtype == t
365365

366366

367+
@skip_xp_backends("jax.numpy", reason="JAX does not raise")
367368
def test_label_output_wrong_size(xp):
368-
if is_jax(xp):
369-
pytest.xfail("JAX does not raise")
370-
371369
data = xp.ones([5])
372370
for t in types:
373371
dtype = getattr(xp, t)
@@ -1136,11 +1134,9 @@ def test_maximum_position06(xp):
11361134
assert output[1] == (1, 1)
11371135

11381136

1137+
@xfail_xp_backends("torch", reason="output[1] is wrong on pytorch")
11391138
def test_maximum_position07(xp):
11401139
# Test float labels
1141-
if is_torch(xp):
1142-
pytest.xfail("output[1] is wrong on pytorch")
1143-
11441140
labels = xp.asarray([1.0, 2.5, 0.0, 4.5])
11451141
for type in types:
11461142
dtype = getattr(xp, type)

scipy/stats/tests/test_entropy.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66

77
from scipy import stats
88
from scipy.stats import norm, expon # type: ignore[attr-defined]
9-
from scipy._lib._array_api import array_namespace, is_array_api_strict, is_jax
9+
from scipy._lib._array_api import array_namespace
1010
from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal,
1111
xp_assert_less)
1212

13+
skip_xp_backends = pytest.mark.skip_xp_backends
14+
1315
class TestEntropy:
1416
def test_entropy_positive(self, xp):
1517
# See ticket #497
@@ -224,13 +226,21 @@ def test_input_validation(self, xp):
224226
with pytest.raises(ValueError, match=message):
225227
stats.differential_entropy(x, method='ekki-ekki')
226228

227-
@pytest.mark.parametrize('method', ['vasicek', 'van es',
228-
'ebrahimi', 'correa'])
229+
@pytest.mark.parametrize('method', [
230+
'vasicek',
231+
'van es',
232+
pytest.param(
233+
'ebrahimi',
234+
marks=skip_xp_backends("jax.numpy",
235+
reason="JAX doesn't support item assignment")
236+
),
237+
pytest.param(
238+
'correa',
239+
marks=skip_xp_backends("array_api_strict",
240+
reason="Needs fancy indexing.")
241+
)
242+
])
229243
def test_consistency(self, method, xp):
230-
if is_jax(xp) and method == 'ebrahimi':
231-
pytest.xfail("Needs array assignment.")
232-
elif is_array_api_strict(xp) and method == 'correa':
233-
pytest.xfail("Needs fancy indexing.")
234244
# test that method is a consistent estimator
235245
n = 10000 if method == 'correa' else 1000000
236246
rvs = stats.norm.rvs(size=n, random_state=0)
@@ -258,17 +268,25 @@ def test_consistency(self, method, xp):
258268
rmse_std_cases = {norm: norm_rmse_std_cases,
259269
expon: expon_rmse_std_cases}
260270

261-
@pytest.mark.parametrize('method', ['vasicek', 'van es', 'ebrahimi', 'correa'])
271+
@pytest.mark.parametrize('method', [
272+
'vasicek',
273+
'van es',
274+
pytest.param(
275+
'ebrahimi',
276+
marks=skip_xp_backends("jax.numpy",
277+
reason="JAX doesn't support item assignment")
278+
),
279+
pytest.param(
280+
'correa',
281+
marks=skip_xp_backends("array_api_strict",
282+
reason="Needs fancy indexing.")
283+
)
284+
])
262285
@pytest.mark.parametrize('dist', [norm, expon])
263286
def test_rmse_std(self, method, dist, xp):
264287
# test that RMSE and standard deviation of estimators matches values
265288
# given in differential_entropy reference [6]. Incidentally, also
266289
# tests vectorization.
267-
if is_jax(xp) and method == 'ebrahimi':
268-
pytest.xfail("Needs array assignment.")
269-
elif is_array_api_strict(xp) and method == 'correa':
270-
pytest.xfail("Needs fancy indexing.")
271-
272290
reps, n, m = 10000, 50, 7
273291
expected = self.rmse_std_cases[dist][method]
274292
rmse_expected, std_expected = xp.asarray(expected[0]), xp.asarray(expected[1])
@@ -282,12 +300,15 @@ def test_rmse_std(self, method, dist, xp):
282300
xp_test = array_namespace(res)
283301
xp_assert_close(xp_test.std(res, correction=0), std_expected, atol=0.002)
284302

285-
@pytest.mark.parametrize('n, method', [(8, 'van es'),
286-
(12, 'ebrahimi'),
287-
(1001, 'vasicek')])
303+
@pytest.mark.parametrize('n, method', [
304+
(8, 'van es'),
305+
pytest.param(
306+
12, 'ebrahimi',
307+
marks=skip_xp_backends("jax.numpy", reason="Needs array assignment")
308+
),
309+
(1001, 'vasicek')
310+
])
288311
def test_method_auto(self, n, method, xp):
289-
if is_jax(xp) and method == 'ebrahimi':
290-
pytest.xfail("Needs array assignment.")
291312
rvs = stats.norm.rvs(size=(n,), random_state=0)
292313
rvs = xp.asarray(rvs.tolist())
293314
res1 = stats.differential_entropy(rvs)
@@ -296,14 +317,20 @@ def test_method_auto(self, n, method, xp):
296317

297318
@pytest.mark.skip_xp_backends('jax.numpy',
298319
reason="JAX doesn't support item assignment")
299-
@pytest.mark.parametrize('method', ["vasicek", "van es", "correa", "ebrahimi"])
320+
@pytest.mark.parametrize('method', [
321+
"vasicek",
322+
"van es",
323+
pytest.param(
324+
"correa",
325+
marks=skip_xp_backends("array_api_strict", reason="Needs fancy indexing.")
326+
),
327+
"ebrahimi"
328+
])
300329
@pytest.mark.parametrize('dtype', [None, 'float32', 'float64'])
301330
def test_dtypes_gh21192(self, xp, method, dtype):
302331
# gh-21192 noted a change in the output of method='ebrahimi'
303332
# with integer input. Check that the output is consistent regardless
304333
# of input dtype.
305-
if is_array_api_strict(xp) and method == 'correa':
306-
pytest.xfail("Needs fancy indexing.")
307334
x = [1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 8, 9, 10, 11]
308335
dtype_in = getattr(xp, str(dtype), None)
309336
dtype_out = getattr(xp, str(dtype), xp.asarray(1.).dtype)

0 commit comments

Comments
 (0)