Skip to content

Commit 2df89b2

Browse files
author
jax authors
committed
Merge pull request #20569 from jakevdp:fix-ks-test
PiperOrigin-RevId: 621608564
2 parents fed7efd + 31e2358 commit 2df89b2

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/random_lax_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ def _CheckKolmogorovSmirnovCDF(self, samples, cdf, pval=None):
6565
# whether RBG keys may be involved, but that's no longer exact.
6666
if config.enable_custom_prng.value and samples.dtype == jnp.bfloat16:
6767
return
68+
# kstest fails for infinities starting in scipy 1.12
69+
# (https://github.com/scipy/scipy/issues/20386)
70+
# TODO(jakevdp): remove this logic if/when fixed upstream.
71+
scipy_version = jtu.parse_version(scipy.__version__)
72+
if scipy_version >= (1, 12) and np.issubdtype(samples.dtype, np.floating):
73+
samples = np.array(samples, copy=True)
74+
samples[np.isposinf(samples)] = 0.01 * np.finfo(samples.dtype).max
75+
samples[np.isneginf(samples)] = 0.01 * np.finfo(samples.dtype).min
6876
self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)
6977

7078
def _CheckChiSquared(self, samples, pmf, *, pval=None):
@@ -742,9 +750,6 @@ def testParetoShape(self):
742750
)
743751
@jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times
744752
def testT(self, df, dtype):
745-
scipy_version = jtu.parse_version(scipy.__version__)
746-
if scipy_version >= (1, 13):
747-
self.skipTest("ks test returns NaN on SciPy 1.13")
748753
key = lambda: self.make_key(1)
749754
rand = lambda key, df: random.t(key, df, (10000,), dtype)
750755
crand = jax.jit(rand)

0 commit comments

Comments
 (0)