6
6
7
7
from scipy import stats
8
8
from scipy .stats import norm , expon # type: ignore[attr-defined]
9
- from scipy ._lib ._array_api import array_namespace , is_array_api_strict , is_jax
9
+ from scipy ._lib ._array_api import array_namespace
10
10
from scipy ._lib ._array_api_no_0d import (xp_assert_close , xp_assert_equal ,
11
11
xp_assert_less )
12
12
13
+ skip_xp_backends = pytest .mark .skip_xp_backends
14
+
13
15
class TestEntropy :
14
16
def test_entropy_positive (self , xp ):
15
17
# See ticket #497
@@ -224,13 +226,21 @@ def test_input_validation(self, xp):
224
226
with pytest .raises (ValueError , match = message ):
225
227
stats .differential_entropy (x , method = 'ekki-ekki' )
226
228
227
- @pytest .mark .parametrize ('method' , ['vasicek' , 'van es' ,
228
- 'ebrahimi' , 'correa' ])
229
+ @pytest .mark .parametrize ('method' , [
230
+ 'vasicek' ,
231
+ 'van es' ,
232
+ pytest .param (
233
+ 'ebrahimi' ,
234
+ marks = skip_xp_backends ("jax.numpy" ,
235
+ reason = "JAX doesn't support item assignment" )
236
+ ),
237
+ pytest .param (
238
+ 'correa' ,
239
+ marks = skip_xp_backends ("array_api_strict" ,
240
+ reason = "Needs fancy indexing." )
241
+ )
242
+ ])
229
243
def test_consistency (self , method , xp ):
230
- if is_jax (xp ) and method == 'ebrahimi' :
231
- pytest .xfail ("Needs array assignment." )
232
- elif is_array_api_strict (xp ) and method == 'correa' :
233
- pytest .xfail ("Needs fancy indexing." )
234
244
# test that method is a consistent estimator
235
245
n = 10000 if method == 'correa' else 1000000
236
246
rvs = stats .norm .rvs (size = n , random_state = 0 )
@@ -258,17 +268,25 @@ def test_consistency(self, method, xp):
258
268
rmse_std_cases = {norm : norm_rmse_std_cases ,
259
269
expon : expon_rmse_std_cases }
260
270
261
- @pytest .mark .parametrize ('method' , ['vasicek' , 'van es' , 'ebrahimi' , 'correa' ])
271
+ @pytest .mark .parametrize ('method' , [
272
+ 'vasicek' ,
273
+ 'van es' ,
274
+ pytest .param (
275
+ 'ebrahimi' ,
276
+ marks = skip_xp_backends ("jax.numpy" ,
277
+ reason = "JAX doesn't support item assignment" )
278
+ ),
279
+ pytest .param (
280
+ 'correa' ,
281
+ marks = skip_xp_backends ("array_api_strict" ,
282
+ reason = "Needs fancy indexing." )
283
+ )
284
+ ])
262
285
@pytest .mark .parametrize ('dist' , [norm , expon ])
263
286
def test_rmse_std (self , method , dist , xp ):
264
287
# test that RMSE and standard deviation of estimators matches values
265
288
# given in differential_entropy reference [6]. Incidentally, also
266
289
# tests vectorization.
267
- if is_jax (xp ) and method == 'ebrahimi' :
268
- pytest .xfail ("Needs array assignment." )
269
- elif is_array_api_strict (xp ) and method == 'correa' :
270
- pytest .xfail ("Needs fancy indexing." )
271
-
272
290
reps , n , m = 10000 , 50 , 7
273
291
expected = self .rmse_std_cases [dist ][method ]
274
292
rmse_expected , std_expected = xp .asarray (expected [0 ]), xp .asarray (expected [1 ])
@@ -282,12 +300,15 @@ def test_rmse_std(self, method, dist, xp):
282
300
xp_test = array_namespace (res )
283
301
xp_assert_close (xp_test .std (res , correction = 0 ), std_expected , atol = 0.002 )
284
302
285
- @pytest .mark .parametrize ('n, method' , [(8 , 'van es' ),
286
- (12 , 'ebrahimi' ),
287
- (1001 , 'vasicek' )])
303
+ @pytest .mark .parametrize ('n, method' , [
304
+ (8 , 'van es' ),
305
+ pytest .param (
306
+ 12 , 'ebrahimi' ,
307
+ marks = skip_xp_backends ("jax.numpy" , reason = "Needs array assignment" )
308
+ ),
309
+ (1001 , 'vasicek' )
310
+ ])
288
311
def test_method_auto (self , n , method , xp ):
289
- if is_jax (xp ) and method == 'ebrahimi' :
290
- pytest .xfail ("Needs array assignment." )
291
312
rvs = stats .norm .rvs (size = (n ,), random_state = 0 )
292
313
rvs = xp .asarray (rvs .tolist ())
293
314
res1 = stats .differential_entropy (rvs )
@@ -296,14 +317,20 @@ def test_method_auto(self, n, method, xp):
296
317
297
318
@pytest .mark .skip_xp_backends ('jax.numpy' ,
298
319
reason = "JAX doesn't support item assignment" )
299
- @pytest .mark .parametrize ('method' , ["vasicek" , "van es" , "correa" , "ebrahimi" ])
320
+ @pytest .mark .parametrize ('method' , [
321
+ "vasicek" ,
322
+ "van es" ,
323
+ pytest .param (
324
+ "correa" ,
325
+ marks = skip_xp_backends ("array_api_strict" , reason = "Needs fancy indexing." )
326
+ ),
327
+ "ebrahimi"
328
+ ])
300
329
@pytest .mark .parametrize ('dtype' , [None , 'float32' , 'float64' ])
301
330
def test_dtypes_gh21192 (self , xp , method , dtype ):
302
331
# gh-21192 noted a change in the output of method='ebrahimi'
303
332
# with integer input. Check that the output is consistent regardless
304
333
# of input dtype.
305
- if is_array_api_strict (xp ) and method == 'correa' :
306
- pytest .xfail ("Needs fancy indexing." )
307
334
x = [1 , 1 , 2 , 3 , 3 , 4 , 5 , 5 , 6 , 7 , 8 , 9 , 10 , 11 ]
308
335
dtype_in = getattr (xp , str (dtype ), None )
309
336
dtype_out = getattr (xp , str (dtype ), xp .asarray (1. ).dtype )
0 commit comments