Skip to content

Commit 6595229

Browse files
authored
FIX Lazy instantiate the ThreadpoolController (scikit-learn#29235)
1 parent 55ca335 commit 6595229

File tree

11 files changed

+70
-38
lines changed

11 files changed

+70
-38
lines changed

doc/whats_new/v1.5.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ Version 1.5.1
2323
Changelog
2424
---------
2525

26+
Changes impacting many modules
27+
------------------------------
28+
29+
- |Fix| Fixed a regression causing a dead-lock at import time in some settings.
30+
:pr:`29235` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
31+
2632
:mod:`sklearn.metrics`
2733
......................
2834

sklearn/__init__.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,6 @@
7070
# We are not importing the rest of scikit-learn during the build
7171
# process, as it may not be compiled yet
7272
else:
73-
# Import numpy, scipy to make sure that the BLAS libs are loaded before
74-
# creating the ThreadpoolController. They would be imported just after
75-
# when importing utils anyway. This makes it explicit and robust to changes
76-
# in utils.
77-
# (OpenMP is loaded by importing show_versions right after this block)
78-
import numpy # noqa
79-
import scipy.linalg # noqa
80-
from threadpoolctl import ThreadpoolController
81-
8273
# `_distributor_init` allows distributors to run custom init code.
8374
# For instance, for the Windows wheel, this is used to pre-load the
8475
# vcomp shared library runtime for OpenMP embedded in the sklearn/.libs
@@ -147,12 +138,6 @@
147138
except ModuleNotFoundError:
148139
pass
149140

150-
# Set a global controller that can be used to locally limit the number of
151-
# threads without looping through all shared libraries every time.
152-
# This instantitation should not happen earlier because it needs all BLAS and
153-
# OpenMP libs to be loaded first.
154-
_threadpool_controller = ThreadpoolController()
155-
156141

157142
def setup_module(module):
158143
"""Fixture for the tests to assure globally controllable seeding of RNGs"""

sklearn/cluster/_kmeans.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import numpy as np
1919
import scipy.sparse as sp
2020

21-
from .. import _threadpool_controller
2221
from ..base import (
2322
BaseEstimator,
2423
ClassNamePrefixFeaturesOutMixin,
@@ -32,6 +31,10 @@
3231
from ..utils._openmp_helpers import _openmp_effective_n_threads
3332
from ..utils._param_validation import Interval, StrOptions, validate_params
3433
from ..utils.extmath import row_norms, stable_cumsum
34+
from ..utils.parallel import (
35+
_get_threadpool_controller,
36+
_threadpool_controller_decorator,
37+
)
3538
from ..utils.sparsefuncs import mean_variance_axis
3639
from ..utils.sparsefuncs_fast import assign_rows_csr
3740
from ..utils.validation import (
@@ -624,7 +627,7 @@ def _kmeans_single_elkan(
624627

625628
# Threadpoolctl context to limit the number of threads in second level of
626629
# nested parallelism (i.e. BLAS) to avoid oversubscription.
627-
@_threadpool_controller.wrap(limits=1, user_api="blas")
630+
@_threadpool_controller_decorator(limits=1, user_api="blas")
628631
def _kmeans_single_lloyd(
629632
X,
630633
sample_weight,
@@ -827,7 +830,7 @@ def _labels_inertia(X, sample_weight, centers, n_threads=1, return_inertia=True)
827830

828831

829832
# Same as _labels_inertia but in a threadpool_limits context.
830-
_labels_inertia_threadpool_limit = _threadpool_controller.wrap(
833+
_labels_inertia_threadpool_limit = _threadpool_controller_decorator(
831834
limits=1, user_api="blas"
832835
)(_labels_inertia)
833836

@@ -922,7 +925,7 @@ def _check_mkl_vcomp(self, X, n_samples):
922925

923926
n_active_threads = int(np.ceil(n_samples / CHUNK_SIZE))
924927
if n_active_threads < self._n_threads:
925-
modules = _threadpool_controller.info()
928+
modules = _get_threadpool_controller().info()
926929
has_vcomp = "vcomp" in [module["prefix"] for module in modules]
927930
has_mkl = ("mkl", "intel") in [
928931
(module["internal_api"], module.get("threading_layer", None))
@@ -2144,7 +2147,7 @@ def fit(self, X, y=None, sample_weight=None):
21442147

21452148
n_steps = (self.max_iter * n_samples) // self._batch_size
21462149

2147-
with _threadpool_controller.limit(limits=1, user_api="blas"):
2150+
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
21482151
# Perform the iterative optimization until convergence
21492152
for i in range(n_steps):
21502153
# Sample a minibatch from the full dataset
@@ -2270,7 +2273,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
22702273
# Initialize number of samples seen since last reassignment
22712274
self._n_since_last_reassign = 0
22722275

2273-
with _threadpool_controller.limit(limits=1, user_api="blas"):
2276+
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
22742277
_mini_batch_step(
22752278
X,
22762279
sample_weight=sample_weight,

sklearn/cluster/tests/test_k_means.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pytest
99
from scipy import sparse as sp
1010

11-
from sklearn import _threadpool_controller
1211
from sklearn.base import clone
1312
from sklearn.cluster import KMeans, MiniBatchKMeans, k_means, kmeans_plusplus
1413
from sklearn.cluster._k_means_common import (
@@ -33,6 +32,7 @@
3332
)
3433
from sklearn.utils.extmath import row_norms
3534
from sklearn.utils.fixes import CSR_CONTAINERS
35+
from sklearn.utils.parallel import _get_threadpool_controller
3636

3737
# non centered, sparse centers to check the
3838
centers = np.array(
@@ -968,13 +968,13 @@ def test_result_equal_in_diff_n_threads(Estimator, global_random_seed):
968968
rnd = np.random.RandomState(global_random_seed)
969969
X = rnd.normal(size=(50, 10))
970970

971-
with _threadpool_controller.limit(limits=1, user_api="openmp"):
971+
with _get_threadpool_controller().limit(limits=1, user_api="openmp"):
972972
result_1 = (
973973
Estimator(n_clusters=n_clusters, random_state=global_random_seed)
974974
.fit(X)
975975
.labels_
976976
)
977-
with _threadpool_controller.limit(limits=2, user_api="openmp"):
977+
with _get_threadpool_controller().limit(limits=2, user_api="openmp"):
978978
result_2 = (
979979
Estimator(n_clusters=n_clusters, random_state=global_random_seed)
980980
.fit(X)

sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ from numbers import Integral
1414
from scipy.sparse import issparse
1515
from ...utils import check_array, check_scalar
1616
from ...utils.fixes import _in_unstable_openblas_configuration
17-
from ... import _threadpool_controller
17+
from ...utils.parallel import _get_threadpool_controller
1818

1919
{{for name_suffix in ['64', '32']}}
2020

@@ -58,7 +58,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
5858
"""
5959
# Limit the number of threads in second level of nested parallelism for BLAS
6060
# to avoid threads over-subscription (in DOT or GEMM for instance).
61-
with _threadpool_controller.limit(limits=1, user_api='blas'):
61+
with _get_threadpool_controller().limit(limits=1, user_api='blas'):
6262
if metric in ("euclidean", "sqeuclidean"):
6363
# Specialized implementation of ArgKmin for the Euclidean distance
6464
# for the dense-dense and sparse-sparse cases.

sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx.tp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ from libcpp.map cimport map as cpp_map, pair as cpp_pair
44
from libc.stdlib cimport free
55

66
from ...utils._typedefs cimport intp_t, float64_t
7-
from ... import _threadpool_controller
7+
from ...utils.parallel import _get_threadpool_controller
88

99
import numpy as np
1010
from scipy.sparse import issparse
@@ -66,7 +66,7 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
6666

6767
# Limit the number of threads in second level of nested parallelism for BLAS
6868
# to avoid threads over-subscription (in GEMM for instance).
69-
with _threadpool_controller.limit(limits=1, user_api="blas"):
69+
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
7070
if pda.execute_in_parallel_on_Y:
7171
pda._parallel_on_Y()
7272
else:

sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ from numbers import Real
1717
from scipy.sparse import issparse
1818
from ...utils import check_array, check_scalar
1919
from ...utils.fixes import _in_unstable_openblas_configuration
20-
from ... import _threadpool_controller
20+
from ...utils.parallel import _get_threadpool_controller
2121

2222
cnp.import_array()
2323

@@ -110,7 +110,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}})
110110

111111
# Limit the number of threads in second level of nested parallelism for BLAS
112112
# to avoid threads over-subscription (in GEMM for instance).
113-
with _threadpool_controller.limit(limits=1, user_api="blas"):
113+
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
114114
if pda.execute_in_parallel_on_Y:
115115
pda._parallel_on_Y()
116116
else:

sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ from ...utils._typedefs cimport intp_t, float64_t
88

99
import numpy as np
1010
from scipy.sparse import issparse
11-
from ... import _threadpool_controller
11+
from ...utils.parallel import _get_threadpool_controller
1212

1313

1414
{{for name_suffix in ["32", "64"]}}
@@ -60,7 +60,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix}
6060

6161
# Limit the number of threads in second level of nested parallelism for BLAS
6262
# to avoid threads over-subscription (in GEMM for instance).
63-
with _threadpool_controller.limit(limits=1, user_api="blas"):
63+
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
6464
if pda.execute_in_parallel_on_Y:
6565
pda._parallel_on_Y()
6666
else:

sklearn/metrics/tests/test_pairwise_distances_reduction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pytest
88
from scipy.spatial.distance import cdist
99

10-
from sklearn import _threadpool_controller
1110
from sklearn.metrics import euclidean_distances, pairwise_distances
1211
from sklearn.metrics._pairwise_distances_reduction import (
1312
ArgKmin,
@@ -23,6 +22,7 @@
2322
create_memmap_backed_data,
2423
)
2524
from sklearn.utils.fixes import CSR_CONTAINERS
25+
from sklearn.utils.parallel import _get_threadpool_controller
2626

2727
# Common supported metric between scipy.spatial.distance.cdist
2828
# and BaseDistanceReductionDispatcher.
@@ -1200,7 +1200,7 @@ def test_n_threads_agnosticism(
12001200
**compute_parameters,
12011201
)
12021202

1203-
with _threadpool_controller.limit(limits=1, user_api="openmp"):
1203+
with _get_threadpool_controller().limit(limits=1, user_api="openmp"):
12041204
dist, indices = Dispatcher.compute(
12051205
X,
12061206
Y,

sklearn/utils/fixes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
import scipy.sparse.linalg
2020
import scipy.stats
2121

22-
import sklearn
23-
2422
from ..externals._packaging.version import parse as parse_version
23+
from .parallel import _get_threadpool_controller
2524

2625
_IS_32BIT = 8 * struct.calcsize("P") == 32
2726
_IS_WASM = platform.machine() in ["wasm32", "wasm64"]
@@ -390,7 +389,7 @@ def _in_unstable_openblas_configuration():
390389
import numpy # noqa
391390
import scipy # noqa
392391

393-
modules_info = sklearn._threadpool_controller.info()
392+
modules_info = _get_threadpool_controller().info()
394393

395394
open_blas_used = any(info["internal_api"] == "openblas" for info in modules_info)
396395
if not open_blas_used:

0 commit comments

Comments
 (0)