Skip to content

Commit 876c235

Browse files
MAINT Parameters validation for sklearn.utils.extmath.randomized_svd (scikit-learn#26690)
Co-authored-by: shreesha3112 <shreesha3112.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent d9212de commit 876c235

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -307,9 +307,10 @@ def _check_function_param_validation(
307307
"sklearn.tree.export_text",
308308
"sklearn.tree.plot_tree",
309309
"sklearn.utils.gen_batches",
310-
"sklearn.utils.graph.single_source_shortest_path_length",
311310
"sklearn.utils.resample",
312311
"sklearn.utils.safe_mask",
312+
"sklearn.utils.extmath.randomized_svd",
313+
"sklearn.utils.graph.single_source_shortest_path_length",
313314
]
314315

315316

sklearn/utils/extmath.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,12 @@
1212
# License: BSD 3 clause
1313

1414
import warnings
15+
from numbers import Integral
1516

1617
import numpy as np
1718
from scipy import linalg, sparse
1819

20+
from ..utils._param_validation import Interval, StrOptions, validate_params
1921
from . import check_random_state
2022
from ._array_api import _is_numpy_namespace, get_namespace
2123
from ._logistic_sigmoid import _log_logistic_sigmoid
@@ -288,6 +290,20 @@ def randomized_range_finder(
288290
return Q
289291

290292

293+
@validate_params(
294+
{
295+
"M": [np.ndarray, "sparse matrix"],
296+
"n_components": [Interval(Integral, 1, None, closed="left")],
297+
"n_oversamples": [Interval(Integral, 0, None, closed="left")],
298+
"n_iter": [Interval(Integral, 0, None, closed="left"), StrOptions({"auto"})],
299+
"power_iteration_normalizer": [StrOptions({"auto", "QR", "LU", "none"})],
300+
"transpose": ["boolean", StrOptions({"auto"})],
301+
"flip_sign": ["boolean"],
302+
"random_state": ["random_state"],
303+
"svd_lapack_driver": [StrOptions({"gesdd", "gesvd"})],
304+
},
305+
prefer_skip_nested_validation=True,
306+
)
291307
def randomized_svd(
292308
M,
293309
n_components,
@@ -314,9 +330,9 @@ def randomized_svd(
314330
Number of singular values and vectors to extract.
315331
316332
n_oversamples : int, default=10
317-
Additional number of random vectors to sample the range of M so as
333+
Additional number of random vectors to sample the range of `M` so as
318334
to ensure proper conditioning. The total number of random vectors
319-
used to find the range of M is n_components + n_oversamples. Smaller
335+
used to find the range of `M` is `n_components + n_oversamples`. Smaller
320336
number can improve speed but can negatively impact the quality of
321337
approximation of singular vectors and singular values. Users might wish
322338
to increase this parameter up to `2*k - n_components` where k is the

0 commit comments

Comments
 (0)