Skip to content

Commit 81f62ab

Browse files
MAINT Parameters validation for sklearn.linear_model.orthogonal_mp_gram (scikit-learn#26382)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 315a04c commit 81f62ab

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

sklearn/linear_model/_omp.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,20 @@ def orthogonal_mp(
447447
return np.squeeze(coef)
448448

449449

450+
@validate_params(
451+
{
452+
"Gram": ["array-like"],
453+
"Xy": ["array-like"],
454+
"n_nonzero_coefs": [Interval(Integral, 0, None, closed="neither"), None],
455+
"tol": [Interval(Real, 0, None, closed="left"), None],
456+
"norms_squared": ["array-like", None],
457+
"copy_Gram": ["boolean"],
458+
"copy_Xy": ["boolean"],
459+
"return_path": ["boolean"],
460+
"return_n_iter": ["boolean"],
461+
},
462+
prefer_skip_nested_validation=True,
463+
)
450464
def orthogonal_mp_gram(
451465
Gram,
452466
Xy,
@@ -468,30 +482,30 @@ def orthogonal_mp_gram(
468482
469483
Parameters
470484
----------
471-
Gram : ndarray of shape (n_features, n_features)
472-
Gram matrix of the input data: X.T * X.
485+
Gram : array-like of shape (n_features, n_features)
486+
Gram matrix of the input data: `X.T * X`.
473487
474-
Xy : ndarray of shape (n_features,) or (n_features, n_targets)
475-
Input targets multiplied by X: X.T * y.
488+
Xy : array-like of shape (n_features,) or (n_features, n_targets)
489+
Input targets multiplied by `X`: `X.T * y`.
476490
477491
n_nonzero_coefs : int, default=None
478-
Desired number of non-zero entries in the solution. If None (by
492+
Desired number of non-zero entries in the solution. If `None` (by
479493
default) this value is set to 10% of n_features.
480494
481495
tol : float, default=None
482-
Maximum norm of the residual. If not None, overrides n_nonzero_coefs.
496+
Maximum norm of the residual. If not `None`, overrides `n_nonzero_coefs`.
483497
484498
norms_squared : array-like of shape (n_targets,), default=None
485-
Squared L2 norms of the lines of y. Required if tol is not None.
499+
Squared L2 norms of the lines of `y`. Required if `tol` is not None.
486500
487501
copy_Gram : bool, default=True
488-
Whether the gram matrix must be copied by the algorithm. A false
502+
Whether the gram matrix must be copied by the algorithm. A `False`
489503
value is only helpful if it is already Fortran-ordered, otherwise a
490504
copy is made anyway.
491505
492506
copy_Xy : bool, default=True
493-
Whether the covariance vector Xy must be copied by the algorithm.
494-
If False, it may be overwritten.
507+
Whether the covariance vector `Xy` must be copied by the algorithm.
508+
If `False`, it may be overwritten.
495509
496510
return_path : bool, default=False
497511
Whether to return every value of the nonzero coefficients along the
@@ -505,11 +519,11 @@ def orthogonal_mp_gram(
505519
coef : ndarray of shape (n_features,) or (n_features, n_targets)
506520
Coefficients of the OMP solution. If `return_path=True`, this contains
507521
the whole coefficient path. In this case its shape is
508-
(n_features, n_features) or (n_features, n_targets, n_features) and
522+
`(n_features, n_features)` or `(n_features, n_targets, n_features)` and
509523
iterating over the last axis yields coefficients in increasing order
510524
of active features.
511525
512-
n_iters : array-like or int
526+
n_iters : list or int
513527
Number of active features across every target. Returned only if
514528
`return_n_iter` is set to True.
515529

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def _check_function_param_validation(
199199
"sklearn.inspection.permutation_importance",
200200
"sklearn.isotonic.isotonic_regression",
201201
"sklearn.linear_model.orthogonal_mp",
202+
"sklearn.linear_model.orthogonal_mp_gram",
202203
"sklearn.linear_model.ridge_regression",
203204
"sklearn.metrics.accuracy_score",
204205
"sklearn.metrics.auc",

0 commit comments

Comments
 (0)