@@ -65,6 +65,14 @@ def _CheckKolmogorovSmirnovCDF(self, samples, cdf, pval=None):
65
65
# whether RBG keys may be involved, but that's no longer exact.
66
66
if config .enable_custom_prng .value and samples .dtype == jnp .bfloat16 :
67
67
return
68
+ # kstest fails for infinities starting in scipy 1.12
69
+ # (https://github.com/scipy/scipy/issues/20386)
70
+ # TODO(jakevdp): remove this logic if/when fixed upstream.
71
+ scipy_version = jtu .parse_version (scipy .__version__ )
72
+ if scipy_version >= (1 , 12 ) and np .issubdtype (samples .dtype , np .floating ):
73
+ samples = np .array (samples , copy = True )
74
+ samples [np .isposinf (samples )] = 0.01 * np .finfo (samples .dtype ).max
75
+ samples [np .isneginf (samples )] = 0.01 * np .finfo (samples .dtype ).min
68
76
self .assertGreater (scipy .stats .kstest (samples , cdf ).pvalue , fail_prob )
69
77
70
78
def _CheckChiSquared (self , samples , pmf , * , pval = None ):
@@ -742,9 +750,6 @@ def testParetoShape(self):
742
750
)
743
751
@jtu .skip_on_devices ("cpu" , "tpu" ) # TODO(phawkins): slow compilation times
744
752
def testT (self , df , dtype ):
745
- scipy_version = jtu .parse_version (scipy .__version__ )
746
- if scipy_version >= (1 , 13 ):
747
- self .skipTest ("ks test returns NaN on SciPy 1.13" )
748
753
key = lambda : self .make_key (1 )
749
754
rand = lambda key , df : random .t (key , df , (10000 ,), dtype )
750
755
crand = jax .jit (rand )
0 commit comments