Skip to content

Commit 732832a

Browse files
mdhabertylerjereddy
authored andcommitted
MAINT: stats.Mixture: fix inverse functions when mean is undefined (scipy#22337)
* MAINT: stats.Mixture: fix inverse when mean is undefined * Apply suggestions from code review * Update scipy/stats/_distribution_infrastructure.py
1 parent 367bc40 commit 732832a

File tree

2 files changed

+39
-19
lines changed

2 files changed

+39
-19
lines changed

scipy/stats/_distribution_infrastructure.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,25 @@ def _logexpxmexpy(x, y):
12121212
return res
12131213

12141214

1215+
def _guess_bracket(xmin, xmax):
1216+
a = np.full_like(xmin, -1.0)
1217+
b = np.ones_like(xmax)
1218+
1219+
i = np.isfinite(xmin) & np.isfinite(xmax)
1220+
a[i] = xmin[i]
1221+
b[i] = xmax[i]
1222+
1223+
i = np.isfinite(xmin) & ~np.isfinite(xmax)
1224+
a[i] = xmin[i]
1225+
b[i] = xmin[i] + 1
1226+
1227+
i = np.isfinite(xmax) & ~np.isfinite(xmin)
1228+
a[i] = xmax[i] - 1
1229+
b[i] = xmax[i]
1230+
1231+
return a, b
1232+
1233+
12151234
def _log_real_standardize(x):
12161235
"""Standardizes the (complex) logarithm of a real number.
12171236
@@ -2011,27 +2030,13 @@ def f2(x, _p, **kwargs): # named `_p` to avoid conflict with shape `p`
20112030
shape = xmin.shape
20122031
xmin, xmax = np.atleast_1d(xmin, xmax)
20132032

2014-
a = -np.ones_like(xmin)
2015-
b = np.ones_like(xmax)
2016-
2017-
i = np.isfinite(xmin) & np.isfinite(xmax)
2018-
a[i] = xmin[i]
2019-
b[i] = xmax[i]
2020-
2021-
i = np.isfinite(xmin) & ~np.isfinite(xmax)
2022-
a[i] = xmin[i]
2023-
b[i] = xmin[i] + 1
2024-
2025-
i = np.isfinite(xmax) & ~np.isfinite(xmin)
2026-
a[i] = xmax[i] - 1
2027-
b[i] = xmax[i]
2028-
2033+
xl0, xr0 = _guess_bracket(xmin, xmax)
20292034
xmin = xmin.reshape(shape)
20302035
xmax = xmax.reshape(shape)
2031-
a = a.reshape(shape)
2032-
b = b.reshape(shape)
2036+
xl0 = xl0.reshape(shape)
2037+
xr0 = xr0.reshape(shape)
20332038

2034-
res = _bracket_root(f3, xl0=a, xr0=b, xmin=xmin, xmax=xmax, args=args)
2039+
res = _bracket_root(f3, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, args=args)
20352040
# For now, we ignore the status, but I want to use the bracket width
20362041
# as an error estimate - see question 5 at the top.
20372042
xrtol = None if _isnull(self.tol) else self.tol
@@ -4636,7 +4641,8 @@ def _invert(self, fun, p):
46364641
xmin, xmax = self.support()
46374642
fun = getattr(self, fun)
46384643
f = lambda x, p: fun(x) - p # noqa: E731 is silly
4639-
res = _bracket_root(f, xl0=self.mean(), xmin=xmin, xmax=xmax, args=(p,))
4644+
xl0, xr0 = _guess_bracket(xmin, xmax)
4645+
res = _bracket_root(f, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, args=(p,))
46404646
return _chandrupatla(f, a=res.xl, b=res.xr, args=(p,)).x
46414647

46424648
def icdf(self, p, /, *, method=None):

scipy/stats/tests/test_continuous.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,3 +1863,17 @@ def test_properties(self):
18631863
assert X.components[0] == components[0]
18641864
X.weights[0] = weights[1]
18651865
assert X.weights[0] == weights[0]
1866+
1867+
def test_inverse(self):
1868+
# Originally, inverse relied on the mean to start the bracket search.
1869+
# This didn't work for distributions with non-finite mean. Check that
1870+
# this is resolved.
1871+
rng = np.random.default_rng(24358934657854237863456)
1872+
Cauchy = stats.make_distribution(stats.cauchy)
1873+
X0 = Cauchy()
1874+
X = stats.Mixture([X0, X0])
1875+
p = rng.random(size=10)
1876+
np.testing.assert_allclose(X.icdf(p), X0.icdf(p))
1877+
np.testing.assert_allclose(X.iccdf(p), X0.iccdf(p))
1878+
np.testing.assert_allclose(X.ilogcdf(p), X0.ilogcdf(p))
1879+
np.testing.assert_allclose(X.ilogccdf(p), X0.ilogccdf(p))

0 commit comments

Comments
 (0)