Skip to content

Commit f40d2cd

Browse files
authored
Merge pull request scipy#22175 from rgommers/stats-nogil-threadsafety
MAINT: stats: fix thread-safety issues under free-threaded CPython
2 parents f464cc1 + 978a047 commit f40d2cd

File tree

6 files changed

+56
-31
lines changed

6 files changed

+56
-31
lines changed

scipy/stats/_kde.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#-------------------------------------------------------------------------------
1919

2020
# Standard library imports.
21+
import threading
2122
import warnings
2223

2324
# SciPy imports.
@@ -36,6 +37,8 @@
3637

3738
__all__ = ['gaussian_kde']
3839

40+
MVN_LOCK = threading.Lock()
41+
3942

4043
class gaussian_kde:
4144
"""Representation of a kernel-density estimate using Gaussian kernels.
@@ -384,9 +387,10 @@ def integrate_box(self, low_bounds, high_bounds, maxpts=None):
384387
else:
385388
extra_kwds = {}
386389

387-
value, inform = _mvn.mvnun_weighted(low_bounds, high_bounds,
388-
self.dataset, self.weights,
389-
self.covariance, **extra_kwds)
390+
with MVN_LOCK:
391+
value, inform = _mvn.mvnun_weighted(low_bounds, high_bounds,
392+
self.dataset, self.weights,
393+
self.covariance, **extra_kwds)
390394
if inform:
391395
msg = f'An integral in _mvn.mvnun requires more points than {self.d * 1000}'
392396
warnings.warn(msg, stacklevel=2)

scipy/stats/_mannwhitneyu.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import threading
12
import numpy as np
23
from collections import namedtuple
34
from scipy import special
@@ -143,7 +144,10 @@ def build_u_freqs_array(self, maxu):
143144
return configurations / total
144145

145146

146-
_mwu_state = _MWU(0, 0)
147+
# Maintain state for faster repeat calls to `mannwhitneyu`.
148+
# _MWU() is calculated once per thread and stored as an attribute on
149+
# this thread-local variable inside mannwhitneyu().
150+
_mwu_state = threading.local()
147151

148152

149153
def _get_mwu_z(U, n1, n2, t, axis=0, continuity=True):
@@ -461,8 +465,10 @@ def mannwhitneyu(x, y, use_continuity=True, alternative="two-sided",
461465
method = _mwu_choose_method(n1, n2, np.any(t > 1))
462466

463467
if method == "exact":
464-
_mwu_state.set_shapes(n1, n2)
465-
p = _mwu_state.sf(U.astype(int))
468+
if not hasattr(_mwu_state, 's'):
469+
_mwu_state.s = _MWU(0, 0)
470+
_mwu_state.s.set_shapes(n1, n2)
471+
p = _mwu_state.s.sf(U.astype(int))
466472
elif method == "asymptotic":
467473
z = _get_mwu_z(U, n1, n2, t, continuity=use_continuity)
468474
p = stats.norm.sf(z)

scipy/stats/_morestats.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import warnings
3+
import threading
34
from collections import namedtuple
45

56
import numpy as np
@@ -962,7 +963,7 @@ def boxcox_llf(lmb, data):
962963
if xp.isdtype(dt, 'integral'):
963964
data = xp.asarray(data, dtype=xp.float64)
964965
dt = xp.float64
965-
966+
966967
logdata = xp.log(data)
967968

968969
# Compute the variance of the transformed data.
@@ -2628,7 +2629,9 @@ def sf(self, k, n, m):
26282629

26292630

26302631
# Maintain state for faster repeat calls to ansari w/ method='exact'
2631-
_abw_state = _ABW()
2632+
# _ABW() is calculated once per thread and stored as an attribute on
2633+
# this thread-local variable inside ansari().
2634+
_abw_state = threading.local()
26322635

26332636

26342637
@_axis_nan_policy_factory(AnsariResult, n_samples=2)
@@ -2739,6 +2742,10 @@ def ansari(x, y, alternative='two-sided'):
27392742
if alternative not in {'two-sided', 'greater', 'less'}:
27402743
raise ValueError("'alternative' must be 'two-sided',"
27412744
" 'greater', or 'less'.")
2745+
2746+
if not hasattr(_abw_state, 'a'):
2747+
_abw_state.a = _ABW()
2748+
27422749
x, y = asarray(x), asarray(y)
27432750
n = len(x)
27442751
m = len(y)
@@ -2759,14 +2766,14 @@ def ansari(x, y, alternative='two-sided'):
27592766
warnings.warn("Ties preclude use of exact statistic.", stacklevel=2)
27602767
if exact:
27612768
if alternative == 'two-sided':
2762-
pval = 2.0 * np.minimum(_abw_state.cdf(AB, n, m),
2763-
_abw_state.sf(AB, n, m))
2769+
pval = 2.0 * np.minimum(_abw_state.a.cdf(AB, n, m),
2770+
_abw_state.a.sf(AB, n, m))
27642771
elif alternative == 'greater':
27652772
# AB statistic is _smaller_ when ratio of scales is larger,
27662773
# so this is the opposite of the usual calculation
2767-
pval = _abw_state.cdf(AB, n, m)
2774+
pval = _abw_state.a.cdf(AB, n, m)
27682775
else:
2769-
pval = _abw_state.sf(AB, n, m)
2776+
pval = _abw_state.a.sf(AB, n, m)
27702777
return AnsariResult(AB, min(1.0, pval))
27712778

27722779
# otherwise compute normal approximation
@@ -4359,7 +4366,7 @@ def directional_stats(samples, *, axis=0, normalize=True):
43594366
"""
43604367
xp = array_namespace(samples)
43614368
samples = xp.asarray(samples)
4362-
4369+
43634370
if samples.ndim < 2:
43644371
raise ValueError("samples must at least be two-dimensional. "
43654372
f"Instead samples has shape: {tuple(samples.shape)}")

scipy/stats/_multivariate.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Author: Joris Vankerschaver 2013
33
#
44
import math
5+
import threading
56
import numpy as np
67
import scipy.linalg
78
from scipy._lib import doccer
@@ -38,6 +39,7 @@
3839
_LOG_2PI = np.log(2 * np.pi)
3940
_LOG_2 = np.log(2)
4041
_LOG_PI = np.log(np.pi)
42+
MVN_LOCK = threading.Lock()
4143

4244

4345
_doc_random_state = """\
@@ -638,8 +640,9 @@ def _cdf(self, x, mean, cov, maxpts, abseps, releps, lower_limit):
638640

639641
# mvnun expects 1-d arguments, so process points sequentially
640642
def func1d(limits):
641-
return _mvn.mvnun(limits[:n], limits[n:], mean, cov,
642-
maxpts, abseps, releps)[0]
643+
with MVN_LOCK:
644+
return _mvn.mvnun(limits[:n], limits[n:], mean, cov,
645+
maxpts, abseps, releps)[0]
643646

644647
out = np.apply_along_axis(func1d, -1, limits) * signs
645648
return _squeeze_output(out)

scipy/stats/tests/test_hypotests.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
_cdf_cvm, cramervonmises_2samp,
1515
_pval_cvm_2samp_exact, barnard_exact,
1616
boschloo_exact)
17-
from scipy.stats._mannwhitneyu import mannwhitneyu, _mwu_state
17+
from scipy.stats._mannwhitneyu import mannwhitneyu, _mwu_state, _MWU
1818
from .common_tests import check_named_results
1919
from scipy._lib._testutils import _TestPythranFunc
2020
from scipy.stats._axis_nan_policy import SmallSampleWarning, too_small_1d_not_omit
@@ -367,28 +367,30 @@ def test_tie_correct(self):
367367

368368
def test_exact_distribution(self):
369369
# I considered parametrize. I decided against it.
370+
setattr(_mwu_state, 's', _MWU(0, 0))
371+
370372
p_tables = {3: self.pn3, 4: self.pn4, 5: self.pm5, 6: self.pm6}
371373
for n, table in p_tables.items():
372374
for m, p in table.items():
373375
# check p-value against table
374376
u = np.arange(0, len(p))
375-
_mwu_state.set_shapes(m, n)
376-
assert_allclose(_mwu_state.cdf(k=u), p, atol=1e-3)
377+
_mwu_state.s.set_shapes(m, n)
378+
assert_allclose(_mwu_state.s.cdf(k=u), p, atol=1e-3)
377379

378380
# check identity CDF + SF - PMF = 1
379381
# ( In this implementation, SF(U) includes PMF(U) )
380382
u2 = np.arange(0, m*n+1)
381-
assert_allclose(_mwu_state.cdf(k=u2)
382-
+ _mwu_state.sf(k=u2)
383-
- _mwu_state.pmf(k=u2), 1)
383+
assert_allclose(_mwu_state.s.cdf(k=u2)
384+
+ _mwu_state.s.sf(k=u2)
385+
- _mwu_state.s.pmf(k=u2), 1)
384386

385387
# check symmetry about mean of U, i.e. pmf(U) = pmf(m*n-U)
386-
pmf = _mwu_state.pmf(k=u2)
388+
pmf = _mwu_state.s.pmf(k=u2)
387389
assert_allclose(pmf, pmf[::-1])
388390

389391
# check symmetry w.r.t. interchange of m, n
390-
_mwu_state.set_shapes(n, m)
391-
pmf2 = _mwu_state.pmf(k=u2)
392+
_mwu_state.s.set_shapes(n, m)
393+
pmf2 = _mwu_state.s.pmf(k=u2)
392394
assert_allclose(pmf, pmf2)
393395

394396
def test_asymptotic_behavior(self):
@@ -628,22 +630,25 @@ def test_gh19692_smaller_table(self):
628630
m, n = 5, 11
629631
x = rng.random(size=m)
630632
y = rng.random(size=n)
631-
_mwu_state.reset() # reset cache
633+
634+
setattr(_mwu_state, 's', _MWU(0, 0))
635+
_mwu_state.s.reset() # reset cache
636+
632637
res = stats.mannwhitneyu(x, y, method='exact')
633-
shape = _mwu_state.configurations.shape
638+
shape = _mwu_state.s.configurations.shape
634639
assert shape[-1] == min(res.statistic, m*n - res.statistic) + 1
635640
stats.mannwhitneyu(y, x, method='exact')
636-
assert shape == _mwu_state.configurations.shape # same when sizes are reversed
641+
assert shape == _mwu_state.s.configurations.shape # same with reversed sizes
637642

638643
# Also, we weren't exploiting the symmetry of the null distribution
639644
# to its full potential. Ensure that the null distribution is not
640645
# evaluated explicitly for `k > m*n/2`.
641-
_mwu_state.reset() # reset cache
646+
_mwu_state.s.reset() # reset cache
642647
stats.mannwhitneyu(x, 0*y, method='exact', alternative='greater')
643-
shape = _mwu_state.configurations.shape
648+
shape = _mwu_state.s.configurations.shape
644649
assert shape[-1] == 1 # k is smallest possible
645650
stats.mannwhitneyu(0*x, y, method='exact', alternative='greater')
646-
assert shape == _mwu_state.configurations.shape
651+
assert shape == _mwu_state.s.configurations.shape
647652

648653
@pytest.mark.parametrize('alternative', ['less', 'greater', 'two-sided'])
649654
def test_permutation_method(self, alternative):

scipy/stats/tests/test_morestats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ def test_alternative_exact(self):
671671
assert pval_g < 0.05 # level of significance.
672672
# also check if the p-values sum up to 1 plus the probability
673673
# mass under the calculated statistic.
674-
prob = _abw_state.pmf(statistic, len(x1), len(x2))
674+
prob = _abw_state.a.pmf(statistic, len(x1), len(x2))
675675
assert_allclose(pval_g + pval_l, 1 + prob, atol=1e-12)
676676
# also check if one of the one-sided p-value equals half the
677677
# two-sided p-value and the other one-sided p-value is its

0 commit comments

Comments
 (0)