Skip to content

Commit a4ebe19

Browse files
lesteveogriseljeremiedbb
authored
ENH Simplify pytest global random test plugin (scikit-learn#27963)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent 2107404 commit a4ebe19

File tree

5 files changed

+55
-105
lines changed

5 files changed

+55
-105
lines changed

build_tools/azure/test_script.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ if [[ "$BUILD_REASON" == "Schedule" ]]; then
1111
# Enable global random seed randomization to discover seed-sensitive tests
1212
# only on nightly builds.
1313
# https://scikit-learn.org/stable/computing/parallelism.html#environment-variables
14-
export SKLEARN_TESTS_GLOBAL_RANDOM_SEED="any"
14+
export SKLEARN_TESTS_GLOBAL_RANDOM_SEED=$(($RANDOM % 100))
15+
echo "To reproduce this test run, set the following environment variable:"
16+
echo " SKLEARN_TESTS_GLOBAL_RANDOM_SEED=$SKLEARN_TESTS_GLOBAL_RANDOM_SEED",
17+
echo "See: https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed"
1518

1619
# Enable global dtype fixture for all nightly builds to discover
1720
# numerical-sensitive tests.

doc/computing/parallelism.rst

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,14 +232,12 @@ the `global_random_seed`` fixture.
232232
All tests that use this fixture accept the contract that they should
233233
deterministically pass for any seed value from 0 to 99 included.
234234

235-
If the `SKLEARN_TESTS_GLOBAL_RANDOM_SEED` environment variable is set to
236-
`"any"` (which should be the case on nightly builds on the CI), the fixture
237-
will choose an arbitrary seed in the above range (based on the BUILD_NUMBER or
238-
the current day) and all fixtured tests will run for that specific seed. The
239-
goal is to ensure that, over time, our CI will run all tests with different
240-
seeds while keeping the test duration of a single run of the full test suite
241-
limited. This will check that the assertions of tests written to use this
242-
fixture are not dependent on a specific seed value.
235+
In nightly CI builds, the `SKLEARN_TESTS_GLOBAL_RANDOM_SEED` environment
236+
variable is drawn randomly in the above range and all fixtured tests will run
237+
for that specific seed. The goal is to ensure that, over time, our CI will run
238+
all tests with different seeds while keeping the test duration of a single run
239+
of the full test suite limited. This will check that the assertions of tests
240+
written to use this fixture are not dependent on a specific seed value.
243241

244242
The range of admissible seed values is limited to [0, 99] because it is often
245243
not possible to write a test that can work for any possible seed and we want to
@@ -250,8 +248,6 @@ Valid values for `SKLEARN_TESTS_GLOBAL_RANDOM_SEED`:
250248
- `SKLEARN_TESTS_GLOBAL_RANDOM_SEED="42"`: run tests with a fixed seed of 42
251249
- `SKLEARN_TESTS_GLOBAL_RANDOM_SEED="40-42"`: run the tests with all seeds
252250
between 40 and 42 included
253-
- `SKLEARN_TESTS_GLOBAL_RANDOM_SEED="any"`: run the tests with an arbitrary
254-
seed selected between 0 and 99 included
255251
- `SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all"`: run the tests with all seeds
256252
between 0 and 99 included. This can take a long time: only use for individual
257253
tests, not the full test suite!

setup.cfg

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ addopts =
1616
--doctest-modules
1717
--disable-pytest-warnings
1818
--color=yes
19-
# Activate the plugin explicitly to ensure that the seed is reported
20-
# correctly on the CI when running `pytest --pyargs sklearn` from the
21-
# source folder.
22-
-p sklearn.tests.random_seed
2319

2420
[mypy]
2521
ignore_missing_imports = True

sklearn/conftest.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
fetch_rcv1,
2727
fetch_species_distributions,
2828
)
29-
from sklearn.tests import random_seed
3029
from sklearn.utils._testing import get_pytest_filterwarning_lines
3130
from sklearn.utils.fixes import (
3231
_IS_32BIT,
@@ -265,6 +264,51 @@ def pyplot():
265264
pyplot.close("all")
266265

267266

267+
def pytest_generate_tests(metafunc):
268+
"""Parametrization of global_random_seed fixture
269+
270+
based on the SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable.
271+
272+
The goal of this fixture is to prevent tests that use it to be sensitive
273+
to a specific seed value while still being deterministic by default.
274+
275+
See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
276+
variable for instructions on how to use this fixture.
277+
278+
https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed
279+
280+
"""
281+
# When using pytest-xdist this function is called in the xdist workers.
282+
# We rely on SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable which is
283+
# set in before running pytest and is available in xdist workers since they
284+
# are subprocesses.
285+
RANDOM_SEED_RANGE = list(range(100)) # All seeds in [0, 99] should be valid.
286+
random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")
287+
288+
default_random_seeds = [42]
289+
290+
if random_seed_var is None:
291+
random_seeds = default_random_seeds
292+
elif random_seed_var == "all":
293+
random_seeds = RANDOM_SEED_RANGE
294+
else:
295+
if "-" in random_seed_var:
296+
start, stop = random_seed_var.split("-")
297+
random_seeds = list(range(int(start), int(stop) + 1))
298+
else:
299+
random_seeds = [int(random_seed_var)]
300+
301+
if min(random_seeds) < 0 or max(random_seeds) > 99:
302+
raise ValueError(
303+
"The value(s) of the environment variable "
304+
"SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be in the range [0, 99] "
305+
f"(or 'all'), got: {random_seed_var}"
306+
)
307+
308+
if "global_random_seed" in metafunc.fixturenames:
309+
metafunc.parametrize("global_random_seed", random_seeds)
310+
311+
268312
def pytest_configure(config):
269313
# Use matplotlib agg backend during the tests including doctests
270314
try:
@@ -282,10 +326,6 @@ def pytest_configure(config):
282326
allowed_parallelism = max(allowed_parallelism // int(xdist_worker_count), 1)
283327
threadpool_limits(allowed_parallelism)
284328

285-
# Register global_random_seed plugin if it is not already registered
286-
if not config.pluginmanager.hasplugin("sklearn.tests.random_seed"):
287-
config.pluginmanager.register(random_seed)
288-
289329
if environ.get("SKLEARN_WARNINGS_AS_ERRORS", "0") != "0":
290330
# This seems like the only way to programmatically change the config
291331
# filterwarnings. This was suggested in

sklearn/tests/random_seed.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

0 commit comments

Comments
 (0)