Skip to content

Commit 87941ab

Browse files
MAINT Parameters validation for sklearn.utils.gen_even_slices (scikit-learn#26682)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 876c235 commit 87941ab

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ def _check_function_param_validation(
307307
"sklearn.tree.export_text",
308308
"sklearn.tree.plot_tree",
309309
"sklearn.utils.gen_batches",
310+
"sklearn.utils.gen_even_slices",
311+
"sklearn.utils.graph.single_source_shortest_path_length",
310312
"sklearn.utils.resample",
311313
"sklearn.utils.safe_mask",
312314
"sklearn.utils.extmath.randomized_svd",

sklearn/utils/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from . import _joblib, metadata_routing
2020
from ._bunch import Bunch
2121
from ._estimator_html_repr import estimator_html_repr
22-
from ._param_validation import Interval, validate_params
22+
from ._param_validation import Integral, Interval, validate_params
2323
from .class_weight import compute_class_weight, compute_sample_weight
2424
from .deprecation import deprecated
2525
from .discovery import all_estimators
@@ -796,6 +796,14 @@ def gen_batches(n, batch_size, *, min_batch_size=0):
796796
yield slice(start, n)
797797

798798

799+
@validate_params(
800+
{
801+
"n": [Interval(Integral, 1, None, closed="left")],
802+
"n_packs": [Interval(Integral, 1, None, closed="left")],
803+
"n_samples": [Interval(Integral, 1, None, closed="left"), None],
804+
},
805+
prefer_skip_nested_validation=True,
806+
)
799807
def gen_even_slices(n, n_packs, *, n_samples=None):
800808
"""Generator to create `n_packs` evenly spaced slices going up to `n`.
801809
@@ -835,8 +843,6 @@ def gen_even_slices(n, n_packs, *, n_samples=None):
835843
[slice(0, 4, None), slice(4, 7, None), slice(7, 10, None)]
836844
"""
837845
start = 0
838-
if n_packs < 1:
839-
raise ValueError("gen_even_slices got n_packs=%s, must be >=1" % n_packs)
840846
for pack_num in range(n_packs):
841847
this_n = n // n_packs
842848
if pack_num < n % n_packs:

sklearn/utils/tests/test_utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,6 @@ def test_gen_even_slices():
545545
joined_range = list(chain(*[some_range[slice] for slice in gen_even_slices(10, 3)]))
546546
assert_array_equal(some_range, joined_range)
547547

548-
# check that passing negative n_chunks raises an error
549-
slices = gen_even_slices(10, -1)
550-
with pytest.raises(ValueError, match="gen_even_slices got n_packs=-1, must be >=1"):
551-
next(slices)
552-
553548

554549
@pytest.mark.parametrize(
555550
("row_bytes", "max_n_rows", "working_memory", "expected"),

0 commit comments

Comments
 (0)