26
26
fetch_rcv1 ,
27
27
fetch_species_distributions ,
28
28
)
29
- from sklearn .tests import random_seed
30
29
from sklearn .utils ._testing import get_pytest_filterwarning_lines
31
30
from sklearn .utils .fixes import (
32
31
_IS_32BIT ,
@@ -265,6 +264,51 @@ def pyplot():
265
264
pyplot .close ("all" )
266
265
267
266
267
+ def pytest_generate_tests (metafunc ):
268
+ """Parametrization of global_random_seed fixture
269
+
270
+ based on the SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable.
271
+
272
+ The goal of this fixture is to prevent tests that use it to be sensitive
273
+ to a specific seed value while still being deterministic by default.
274
+
275
+ See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
276
+ variable for instructions on how to use this fixture.
277
+
278
+ https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed
279
+
280
+ """
281
+ # When using pytest-xdist this function is called in the xdist workers.
282
+ # We rely on SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable which is
283
+ # set in before running pytest and is available in xdist workers since they
284
+ # are subprocesses.
285
+ RANDOM_SEED_RANGE = list (range (100 )) # All seeds in [0, 99] should be valid.
286
+ random_seed_var = environ .get ("SKLEARN_TESTS_GLOBAL_RANDOM_SEED" )
287
+
288
+ default_random_seeds = [42 ]
289
+
290
+ if random_seed_var is None :
291
+ random_seeds = default_random_seeds
292
+ elif random_seed_var == "all" :
293
+ random_seeds = RANDOM_SEED_RANGE
294
+ else :
295
+ if "-" in random_seed_var :
296
+ start , stop = random_seed_var .split ("-" )
297
+ random_seeds = list (range (int (start ), int (stop ) + 1 ))
298
+ else :
299
+ random_seeds = [int (random_seed_var )]
300
+
301
+ if min (random_seeds ) < 0 or max (random_seeds ) > 99 :
302
+ raise ValueError (
303
+ "The value(s) of the environment variable "
304
+ "SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be in the range [0, 99] "
305
+ f"(or 'all'), got: { random_seed_var } "
306
+ )
307
+
308
+ if "global_random_seed" in metafunc .fixturenames :
309
+ metafunc .parametrize ("global_random_seed" , random_seeds )
310
+
311
+
268
312
def pytest_configure (config ):
269
313
# Use matplotlib agg backend during the tests including doctests
270
314
try :
@@ -282,10 +326,6 @@ def pytest_configure(config):
282
326
allowed_parallelism = max (allowed_parallelism // int (xdist_worker_count ), 1 )
283
327
threadpool_limits (allowed_parallelism )
284
328
285
- # Register global_random_seed plugin if it is not already registered
286
- if not config .pluginmanager .hasplugin ("sklearn.tests.random_seed" ):
287
- config .pluginmanager .register (random_seed )
288
-
289
329
if environ .get ("SKLEARN_WARNINGS_AS_ERRORS" , "0" ) != "0" :
290
330
# This seems like the only way to programmatically change the config
291
331
# filterwarnings. This was suggested in
0 commit comments