Skip to content

Commit 1fae098

Browse files
FIX: Change limits of power_t param to [0, inf) (scikit-learn#31474)
1 parent ab3d34e commit 1fae098

File tree

3 files changed

+76
-3
lines changed

3 files changed

+76
-3
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
- :class:`linear_model.SGDClassifier`, :class:`linear_model.SGDRegressor`, and
2+
:class:`linear_model.SGDOneClassSVM` now deprecate negative values for the
3+
`power_t` parameter. Using a negative value will raise a warning in version 1.8
4+
and will raise an error in version 1.10. A value in the range [0.0, inf) must be used
5+
instead.
6+
By :user:`Ritvi Alagusankar <ritvi-alagusankar>`

sklearn/linear_model/_stochastic_gradient.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,15 @@ def _fit(
731731
),
732732
ConvergenceWarning,
733733
)
734+
735+
if self.power_t < 0:
736+
warnings.warn(
737+
"Negative values for `power_t` are deprecated in version 1.8 "
738+
"and will raise an error in 1.10. "
739+
"Use values in the range [0.0, inf) instead.",
740+
FutureWarning,
741+
)
742+
734743
return self
735744

736745
def _fit_binary(self, X, y, alpha, C, sample_weight, learning_rate, max_iter):
@@ -1082,7 +1091,11 @@ class SGDClassifier(BaseSGDClassifier):
10821091
10831092
power_t : float, default=0.5
10841093
The exponent for inverse scaling learning rate.
1085-
Values must be in the range `(-inf, inf)`.
1094+
Values must be in the range `[0.0, inf)`.
1095+
1096+
.. deprecated:: 1.8
1097+
Negative values for `power_t` are deprecated in version 1.8 and will raise
1098+
an error in 1.10. Use values in the range [0.0, inf) instead.
10861099
10871100
early_stopping : bool, default=False
10881101
Whether to use early stopping to terminate training when validation
@@ -1585,6 +1598,14 @@ def _fit(
15851598
ConvergenceWarning,
15861599
)
15871600

1601+
if self.power_t < 0:
1602+
warnings.warn(
1603+
"Negative values for `power_t` are deprecated in version 1.8 "
1604+
"and will raise an error in 1.10. "
1605+
"Use values in the range [0.0, inf) instead.",
1606+
FutureWarning,
1607+
)
1608+
15881609
return self
15891610

15901611
@_fit_context(prefer_skip_nested_validation=True)
@@ -1880,7 +1901,11 @@ class SGDRegressor(BaseSGDRegressor):
18801901
18811902
power_t : float, default=0.25
18821903
The exponent for inverse scaling learning rate.
1883-
Values must be in the range `(-inf, inf)`.
1904+
Values must be in the range `[0.0, inf)`.
1905+
1906+
.. deprecated:: 1.8
1907+
Negative values for `power_t` are deprecated in version 1.8 and will raise
1908+
an error in 1.10. Use values in the range [0.0, inf) instead.
18841909
18851910
early_stopping : bool, default=False
18861911
Whether to use early stopping to terminate training when validation
@@ -2118,7 +2143,11 @@ class SGDOneClassSVM(OutlierMixin, BaseSGD):
21182143
21192144
power_t : float, default=0.5
21202145
The exponent for inverse scaling learning rate.
2121-
Values must be in the range `(-inf, inf)`.
2146+
Values must be in the range `[0.0, inf)`.
2147+
2148+
.. deprecated:: 1.8
2149+
Negative values for `power_t` are deprecated in version 1.8 and will raise
2150+
an error in 1.10. Use values in the range [0.0, inf) instead.
21222151
21232152
warm_start : bool, default=False
21242153
When set to True, reuse the solution of the previous call to fit as
@@ -2490,6 +2519,14 @@ def _fit(
24902519
ConvergenceWarning,
24912520
)
24922521

2522+
if self.power_t < 0:
2523+
warnings.warn(
2524+
"Negative values for `power_t` are deprecated in version 1.8 "
2525+
"and will raise an error in 1.10. "
2526+
"Use values in the range [0.0, inf) instead.",
2527+
FutureWarning,
2528+
)
2529+
24932530
return self
24942531

24952532
@_fit_context(prefer_skip_nested_validation=True)

sklearn/linear_model/tests/test_sgd.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pickle
2+
import warnings
23
from unittest.mock import Mock
34

45
import joblib
@@ -507,6 +508,35 @@ def test_sgd_failing_penalty_validation(Estimator):
507508
clf.fit(X, Y)
508509

509510

511+
# TODO(1.10): remove this test
512+
@pytest.mark.parametrize(
513+
"klass",
514+
[
515+
SGDClassifier,
516+
SparseSGDClassifier,
517+
SGDRegressor,
518+
SparseSGDRegressor,
519+
SGDOneClassSVM,
520+
SparseSGDOneClassSVM,
521+
],
522+
)
523+
def test_power_t_limits(klass):
524+
"""Check that a warning is raised when `power_t` is negative."""
525+
526+
# Check that negative values of `power_t` raise a warning
527+
clf = klass(power_t=-1.0)
528+
with pytest.warns(
529+
FutureWarning, match="Negative values for `power_t` are deprecated"
530+
):
531+
clf.fit(X, Y)
532+
533+
# Check that values of 'power_t in range [0, inf) do not raise a warning
534+
with warnings.catch_warnings(record=True) as w:
535+
clf = klass(power_t=0.5)
536+
clf.fit(X, Y)
537+
assert len(w) == 0
538+
539+
510540
###############################################################################
511541
# Classification Test Case
512542

0 commit comments

Comments
 (0)