Skip to content

Commit 9ab298a

Browse files
MAINT Parameters validation for sklearn.isotonic.isotonic_regression (scikit-learn#26257)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 7f871fe commit 9ab298a

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

sklearn/isotonic.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ._isotonic import _inplace_contiguous_isotonic_regression, _make_unique
1515
from .base import BaseEstimator, RegressorMixin, TransformerMixin, _fit_context
1616
from .utils import check_array, check_consistent_length
17-
from .utils._param_validation import Interval, StrOptions
17+
from .utils._param_validation import Interval, StrOptions, validate_params
1818
from .utils.validation import _check_sample_weight, check_is_fitted
1919

2020
__all__ = ["check_increasing", "isotonic_regression", "IsotonicRegression"]
@@ -79,6 +79,16 @@ def check_increasing(x, y):
7979
return increasing_bool
8080

8181

82+
@validate_params(
83+
{
84+
"y": ["array-like"],
85+
"sample_weight": ["array-like", None],
86+
"y_min": [Interval(Real, None, None, closed="both"), None],
87+
"y_max": [Interval(Real, None, None, closed="both"), None],
88+
"increasing": ["boolean"],
89+
},
90+
prefer_skip_nested_validation=True,
91+
)
8292
def isotonic_regression(
8393
y, *, sample_weight=None, y_min=None, y_max=None, increasing=True
8494
):

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def _check_function_param_validation(
178178
"sklearn.feature_selection.r_regression",
179179
"sklearn.inspection.partial_dependence",
180180
"sklearn.inspection.permutation_importance",
181+
"sklearn.isotonic.isotonic_regression",
181182
"sklearn.linear_model.orthogonal_mp",
182183
"sklearn.linear_model.ridge_regression",
183184
"sklearn.metrics.accuracy_score",

0 commit comments

Comments
 (0)