Commit dda6337

ENH scaling of LogisticRegression loss as 1/n * LinearModelLoss (scikit-learn#26721)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 2105500 commit dda6337

10 files changed: +181 additions, -77 deletions


doc/whats_new/v1.4.rst

Lines changed: 40 additions & 1 deletion
@@ -19,6 +19,22 @@ parameters, may produce different models from the previous version. This often
 occurs due to changes in the modelling logic (bug fixes or enhancements), or in
 random sampling procedures.
 
+- |Efficiency| :class:`linear_model.LogisticRegression` and
+  :class:`linear_model.LogisticRegressionCV` now have much better convergence for
+  solvers `"lbfgs"` and `"newton-cg"`. Both solvers can now reach much higher precision
+  for the coefficients depending on the specified `tol`. Additionally, lbfgs can
+  make better use of `tol`, i.e., stop sooner or reach higher precision.
+  :pr:`26721` by :user:`Christian Lorentzen <lorentzenchr>`.
+
+  .. note::
+
+    lbfgs is the default solver, so this change might affect many models.
+
+    This change also means that with this new version of scikit-learn, the resulting
+    coefficients `coef_` and `intercept_` of your models will change for these two
+    solvers (when fit on the same data again). The amount of change depends on the
+    specified `tol`; for small values you will get more precise results.
+
 Changes impacting all modules
 -----------------------------
 
@@ -257,7 +273,30 @@ Changelog
 :mod:`sklearn.linear_model`
 ...........................
 
-- |Enhancement| Solver `"newton-cg"` in :class:`linear_model.LogisticRegression` and
+- |Efficiency| :class:`linear_model.LogisticRegression` and
+  :class:`linear_model.LogisticRegressionCV` now have much better convergence for
+  solvers `"lbfgs"` and `"newton-cg"`. Both solvers can now reach much higher precision
+  for the coefficients depending on the specified `tol`. Additionally, lbfgs can
+  make better use of `tol`, i.e., stop sooner or reach higher precision. This is
+  accomplished by better scaling of the objective function, i.e., using average
+  per-sample losses instead of the sum of per-sample losses.
+  :pr:`26721` by :user:`Christian Lorentzen <lorentzenchr>`.
+
+  .. note::
+
+    This change also means that with this new version of scikit-learn, the resulting
+    coefficients `coef_` and `intercept_` of your models will change for these two
+    solvers (when fit on the same data again). The amount of change depends on the
+    specified `tol`; for small values you will get more precise results.
+
+- |Efficiency| :class:`linear_model.LogisticRegression` and
+  :class:`linear_model.LogisticRegressionCV` with solver `"newton-cg"` can now be
+  considerably faster for some data and parameter settings. This is accomplished by a
+  better line search convergence check for negligible loss improvements that takes
+  gradient information into account.
+  :pr:`26721` by :user:`Christian Lorentzen <lorentzenchr>`.
+
+- |Efficiency| Solver `"newton-cg"` in :class:`linear_model.LogisticRegression` and
   :class:`linear_model.LogisticRegressionCV` uses a little less memory. The effect is
   proportional to the number of coefficients (`n_features * n_classes`).
   :pr:`27417` by :user:`Christian Lorentzen <lorentzenchr>`.
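
A quick way to see the user-facing effect described in the entries above (this sketch is not part of the commit; the dataset, solver choices, and tolerances are illustrative): with the rescaled objective, tightening `tol` should move the `"lbfgs"` coefficients progressively closer to a high-precision reference fit.

# Illustrative sketch only, not part of the diff: compare lbfgs fits at
# different tolerances against a tightly converged reference solution.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=500, n_features=20, random_state=0)

# High-precision reference (solver and tol are arbitrary but very strict).
ref = LogisticRegression(solver="newton-cholesky", tol=1e-12).fit(X, y)

for tol in (1e-2, 1e-4, 1e-8):
    clf = LogisticRegression(solver="lbfgs", tol=tol, max_iter=10_000).fit(X, y)
    gap = np.max(np.abs(clf.coef_ - ref.coef_))
    print(f"tol={tol:g}  max|coef - ref| = {gap:.2e}")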

sklearn/feature_selection/_from_model.py

Lines changed: 2 additions & 2 deletions
@@ -211,9 +211,9 @@ class SelectFromModel(
     >>> y = [0, 1, 0, 1]
     >>> selector = SelectFromModel(estimator=LogisticRegression()).fit(X, y)
     >>> selector.estimator_.coef_
-    array([[-0.3252302 ,  0.83462377,  0.49750423]])
+    array([[-0.3252...,  0.8345...,  0.4976...]])
     >>> selector.threshold_
-    0.55245...
+    0.55249...
     >>> selector.get_support()
     array([False,  True, False])
     >>> selector.transform(X)

sklearn/linear_model/_glm/glm.py

Lines changed: 16 additions & 13 deletions
@@ -207,10 +207,10 @@ def fit(self, X, y, sample_weight=None):
         loss_dtype = min(max(y.dtype, X.dtype), np.float64)
         y = check_array(y, dtype=loss_dtype, order="C", ensure_2d=False)
 
-        # TODO: We could support samples_weight=None as the losses support it.
-        # Note that _check_sample_weight calls check_array(order="C") required by
-        # losses.
-        sample_weight = _check_sample_weight(sample_weight, X, dtype=loss_dtype)
+        if sample_weight is not None:
+            # Note that _check_sample_weight calls check_array(order="C") required by
+            # losses.
+            sample_weight = _check_sample_weight(sample_weight, X, dtype=loss_dtype)
 
         n_samples, n_features = X.shape
         self._base_loss = self._get_loss()
@@ -228,17 +228,20 @@ def fit(self, X, y, sample_weight=None):
 
         # TODO: if alpha=0 check that X is not rank deficient
 
-        # IMPORTANT NOTE: Rescaling of sample_weight:
+        # NOTE: Rescaling of sample_weight:
         # We want to minimize
-        #     obj = 1/(2*sum(sample_weight)) * sum(sample_weight * deviance)
+        #     obj = 1/(2 * sum(sample_weight)) * sum(sample_weight * deviance)
         #           + 1/2 * alpha * L2,
         # with
         #     deviance = 2 * loss.
         # The objective is invariant to multiplying sample_weight by a constant. We
-        # choose this constant such that sum(sample_weight) = 1. Thus, we end up with
+        # could choose this constant such that sum(sample_weight) = 1 in order to end
+        # up with
         #     obj = sum(sample_weight * loss) + 1/2 * alpha * L2.
-        # Note that LinearModelLoss.loss() computes sum(sample_weight * loss).
-        sample_weight = sample_weight / sample_weight.sum()
+        # But LinearModelLoss.loss() already computes
+        #     average(loss, weights=sample_weight)
+        # Thus, without rescaling, we have
+        #     obj = LinearModelLoss.loss(...)
 
         if self.warm_start and hasattr(self, "coef_"):
             if self.fit_intercept:
@@ -415,10 +418,10 @@ def score(self, X, y, sample_weight=None):
                 f" {base_loss.__name__}."
             )
 
-        # Note that constant_to_optimal_zero is already multiplied by sample_weight.
-        constant = np.mean(base_loss.constant_to_optimal_zero(y_true=y))
-        if sample_weight is not None:
-            constant *= sample_weight.shape[0] / np.sum(sample_weight)
+        constant = np.average(
+            base_loss.constant_to_optimal_zero(y_true=y, sample_weight=None),
+            weights=sample_weight,
+        )
 
         # Missing factor of 2 in deviance cancels out.
         deviance = base_loss(
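
For the `score` change above, both formulations compute the same weighted average: the removed code worked on per-sample constants that were already multiplied by `sample_weight` and then rescaled by `n / sum(sample_weight)`, which algebraically equals `np.average(z, weights=sample_weight)` on the unweighted constants `z`. A minimal numeric check (illustrative arrays, `z` standing in for `constant_to_optimal_zero(y_true=y, sample_weight=None)`):

import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=10)              # stand-in for the per-sample constants
sw = rng.uniform(0.5, 2.0, size=10)  # sample weights

# Old formulation: mean of weight-multiplied constants, rescaled by n / sum(sw).
old = np.mean(sw * z) * (sw.shape[0] / np.sum(sw))
# New formulation: explicit weighted average (also handles weights=None).
new = np.average(z, weights=sw)

assert np.isclose(old, new)  # both equal sum(sw * z) / sum(sw)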

sklearn/linear_model/_linear_loss.py

Lines changed: 28 additions & 15 deletions
@@ -12,18 +12,19 @@ class LinearModelLoss:
 
     Note that raw_prediction is also known as linear predictor.
 
-    The loss is the sum of per sample losses and includes a term for L2
+    The loss is the average of per sample losses and includes a term for L2
     regularization::
 
-        loss = sum_i s_i loss(y_i, X_i @ coef + intercept)
+        loss = 1 / s_sum * sum_i s_i loss(y_i, X_i @ coef + intercept)
                + 1/2 * l2_reg_strength * ||coef||_2^2
 
-    with sample weights s_i=1 if sample_weight=None.
+    with sample weights s_i=1 if sample_weight=None and s_sum=sum_i s_i.
 
     Gradient and hessian, for simplicity without intercept, are::
 
-        gradient = X.T @ loss.gradient + l2_reg_strength * coef
-        hessian = X.T @ diag(loss.hessian) @ X + l2_reg_strength * identity
+        gradient = 1 / s_sum * X.T @ loss.gradient + l2_reg_strength * coef
+        hessian = 1 / s_sum * X.T @ diag(loss.hessian) @ X
+                  + l2_reg_strength * identity
 
     Conventions:
         if fit_intercept:
@@ -182,7 +183,7 @@ def loss(
         n_threads=1,
         raw_prediction=None,
     ):
-        """Compute the loss as sum over point-wise losses.
+        """Compute the loss as weighted average over point-wise losses.
 
         Parameters
         ----------
@@ -209,7 +210,7 @@ def loss(
         Returns
        -------
         loss : float
-            Sum of losses per sample plus penalty.
+            Weighted average of losses per sample, plus penalty.
         """
         if raw_prediction is None:
             weights, intercept, raw_prediction = self.weight_intercept_raw(coef, X)
@@ -219,10 +220,10 @@
         loss = self.base_loss.loss(
             y_true=y,
             raw_prediction=raw_prediction,
-            sample_weight=sample_weight,
+            sample_weight=None,
             n_threads=n_threads,
         )
-        loss = loss.sum()
+        loss = np.average(loss, weights=sample_weight)
 
         return loss + self.l2_penalty(weights, l2_reg_strength)
 
@@ -263,12 +264,12 @@ def loss_gradient(
         Returns
         -------
         loss : float
-            Sum of losses per sample plus penalty.
+            Weighted average of losses per sample, plus penalty.
 
         gradient : ndarray of shape coef.shape
             The gradient of the loss.
         """
-        n_features, n_classes = X.shape[1], self.base_loss.n_classes
+        (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes
         n_dof = n_features + int(self.fit_intercept)
 
         if raw_prediction is None:
@@ -282,9 +283,12 @@ def loss_gradient(
             sample_weight=sample_weight,
             n_threads=n_threads,
         )
-        loss = loss.sum()
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
+        loss = loss.sum() / sw_sum
         loss += self.l2_penalty(weights, l2_reg_strength)
 
+        grad_pointwise /= sw_sum
+
         if not self.base_loss.is_multiclass:
             grad = np.empty_like(coef, dtype=weights.dtype)
             grad[:n_features] = X.T @ grad_pointwise + l2_reg_strength * weights
@@ -340,7 +344,7 @@ def gradient(
         gradient : ndarray of shape coef.shape
             The gradient of the loss.
         """
-        n_features, n_classes = X.shape[1], self.base_loss.n_classes
+        (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes
         n_dof = n_features + int(self.fit_intercept)
 
         if raw_prediction is None:
@@ -354,6 +358,8 @@ def gradient(
             sample_weight=sample_weight,
             n_threads=n_threads,
         )
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
+        grad_pointwise /= sw_sum
 
         if not self.base_loss.is_multiclass:
             grad = np.empty_like(coef, dtype=weights.dtype)
@@ -439,6 +445,9 @@ def gradient_hessian(
             sample_weight=sample_weight,
             n_threads=n_threads,
         )
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
+        grad_pointwise /= sw_sum
+        hess_pointwise /= sw_sum
 
         # For non-canonical link functions and far away from the optimum, the pointwise
         # hessian can be negative. We take care that 75% of the hessian entries are
@@ -543,6 +552,7 @@ def gradient_hessian_product(
         (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes
         n_dof = n_features + int(self.fit_intercept)
         weights, intercept, raw_prediction = self.weight_intercept_raw(coef, X)
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
 
         if not self.base_loss.is_multiclass:
             grad_pointwise, hess_pointwise = self.base_loss.gradient_hessian(
@@ -551,6 +561,8 @@ def gradient_hessian_product(
                 sample_weight=sample_weight,
                 n_threads=n_threads,
             )
+            grad_pointwise /= sw_sum
+            hess_pointwise /= sw_sum
             grad = np.empty_like(coef, dtype=weights.dtype)
             grad[:n_features] = X.T @ grad_pointwise + l2_reg_strength * weights
             if self.fit_intercept:
@@ -603,6 +615,7 @@ def hessp(s):
                 sample_weight=sample_weight,
                 n_threads=n_threads,
            )
+            grad_pointwise /= sw_sum
             grad = np.empty((n_classes, n_dof), dtype=weights.dtype, order="F")
             grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
             if self.fit_intercept:
@@ -644,9 +657,9 @@ def hessp(s):
                 # hess_prod = empty_like(grad), but we ravel grad below and this
                 # function is run after that.
                 hess_prod = np.empty((n_classes, n_dof), dtype=weights.dtype, order="F")
-                hess_prod[:, :n_features] = tmp.T @ X + l2_reg_strength * s
+                hess_prod[:, :n_features] = (tmp.T @ X) / sw_sum + l2_reg_strength * s
                 if self.fit_intercept:
-                    hess_prod[:, -1] = tmp.sum(axis=0)
+                    hess_prod[:, -1] = tmp.sum(axis=0) / sw_sum
                 if coef.ndim == 1:
                     return hess_prod.ravel(order="F")
                 else:
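
The docstring hunks above define the new objective as a weighted average plus an L2 term, and the later hunks divide the pointwise gradients and Hessians by `sw_sum` accordingly. The standalone sketch below mirrors those formulas for the binary log loss without calling `LinearModelLoss` itself; the helper name, the data, and the loss expression are illustrative.

import numpy as np

def averaged_loss_and_grad(coef, X, y, sample_weight, l2_reg_strength):
    """Average-of-losses objective: 1/s_sum * sum_i s_i * loss_i + L2 term."""
    sw = np.ones(X.shape[0]) if sample_weight is None else sample_weight
    sw_sum = sw.sum()
    raw = X @ coef                                   # raw_prediction, no intercept
    per_sample = np.logaddexp(0.0, raw) - y * raw    # binary log loss, y in {0, 1}
    loss = (sw @ per_sample) / sw_sum + 0.5 * l2_reg_strength * coef @ coef
    grad_pointwise = sw * (1.0 / (1.0 + np.exp(-raw)) - y) / sw_sum
    grad = X.T @ grad_pointwise + l2_reg_strength * coef
    return loss, grad

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = (rng.random(50) < 0.5).astype(float)
print(averaged_loss_and_grad(np.zeros(3), X, y, None, l2_reg_strength=1.0))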

sklearn/linear_model/_logistic.py

Lines changed: 28 additions & 28 deletions
@@ -304,33 +304,16 @@ def _logistic_regression_path(
         # np.unique(y) gives labels in sorted order.
         pos_class = classes[1]
 
-    # If sample weights exist, convert them to array (support for lists)
-    # and check length
-    # Otherwise set them to 1 for all examples
-    sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype, copy=True)
-
-    if solver == "newton-cholesky":
-        # IMPORTANT NOTE: Rescaling of sample_weight:
-        # Same as in _GeneralizedLinearRegressor.fit().
-        # We want to minimize
-        #     obj = 1/(2*sum(sample_weight)) * sum(sample_weight * deviance)
-        #           + 1/2 * alpha * L2,
-        # with
-        #     deviance = 2 * log_loss.
-        # The objective is invariant to multiplying sample_weight by a constant. We
-        # choose this constant such that sum(sample_weight) = 1. Thus, we end up with
-        #     obj = sum(sample_weight * loss) + 1/2 * alpha * L2.
-        # Note that LinearModelLoss.loss() computes sum(sample_weight * loss).
-        #
-        # This rescaling has to be done before multiplying by class_weights.
-        sw_sum = sample_weight.sum()  # needed to rescale penalty, nasty matter!
-        sample_weight = sample_weight / sw_sum
+    if sample_weight is not None or class_weight is not None:
+        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype, copy=True)
 
     # If class_weights is a dict (provided by the user), the weights
     # are assigned to the original labels. If it is "balanced", then
     # the class_weights are assigned after masking the labels with a OvR.
     le = LabelEncoder()
-    if isinstance(class_weight, dict) or multi_class == "multinomial":
+    if isinstance(class_weight, dict) or (
+        multi_class == "multinomial" and class_weight is not None
+    ):
         class_weight_ = compute_class_weight(class_weight, classes=classes, y=y)
         sample_weight *= class_weight_[le.fit_transform(y)]
 
@@ -375,6 +358,19 @@ def _logistic_regression_path(
         (classes.size, n_features + int(fit_intercept)), order="F", dtype=X.dtype
     )
 
+    # IMPORTANT NOTE:
+    # All solvers relying on LinearModelLoss need to scale the penalty with n_samples
+    # or the sum of sample weights because the implemented logistic regression
+    # objective here is (unfortunately)
+    #     C * sum(pointwise_loss) + penalty
+    # instead of (as LinearModelLoss does)
+    #     mean(pointwise_loss) + 1/C * penalty
+    if solver in ["lbfgs", "newton-cg", "newton-cholesky"]:
+        # This needs to be calculated after sample_weight is multiplied by
+        # class_weight. It is even tested that passing class_weight is equivalent to
+        # passing sample_weights according to class_weight.
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
+
     if coef is not None:
         # it must work both giving the bias term and not
         if multi_class == "ovr":
@@ -457,7 +453,7 @@ def _logistic_regression_path(
     n_iter = np.zeros(len(Cs), dtype=np.int32)
     for i, C in enumerate(Cs):
         if solver == "lbfgs":
-            l2_reg_strength = 1.0 / C
+            l2_reg_strength = 1.0 / (C * sw_sum)
             iprint = [-1, 50, 1, 100, 101][
                 np.searchsorted(np.array([0, 1, 2, 3]), verbose)
             ]
@@ -467,7 +463,13 @@ def _logistic_regression_path(
                 method="L-BFGS-B",
                 jac=True,
                 args=(X, target, sample_weight, l2_reg_strength, n_threads),
-                options={"iprint": iprint, "gtol": tol, "maxiter": max_iter},
+                options={
+                    "maxiter": max_iter,
+                    "maxls": 50,  # default is 20
+                    "iprint": iprint,
+                    "gtol": tol,
+                    "ftol": 64 * np.finfo(float).eps,
+                },
             )
             n_iter_i = _check_optimize_result(
                 solver,
@@ -477,15 +479,13 @@ def _logistic_regression_path(
             )
             w0, loss = opt_res.x, opt_res.fun
         elif solver == "newton-cg":
-            l2_reg_strength = 1.0 / C
+            l2_reg_strength = 1.0 / (C * sw_sum)
             args = (X, target, sample_weight, l2_reg_strength, n_threads)
             w0, n_iter_i = _newton_cg(
                 hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol
             )
         elif solver == "newton-cholesky":
-            # The division by sw_sum is a consequence of the rescaling of
-            # sample_weight, see comment above.
-            l2_reg_strength = 1.0 / C / sw_sum
+            l2_reg_strength = 1.0 / (C * sw_sum)
             sol = NewtonCholeskySolver(
                 coef=w0,
                 linear_loss=loss,
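
The penalty rescaling in the hunks above follows from a simple equivalence: dividing the legacy objective `C * sum_i s_i * loss_i(w) + 1/2 * ||w||_2^2` by the positive constant `C * sw_sum` gives `1/sw_sum * sum_i s_i * loss_i(w) + 1/(C * sw_sum) * 1/2 * ||w||_2^2`, which has the same minimizer; hence `l2_reg_strength = 1.0 / (C * sw_sum)` for the solvers driven by the averaged `LinearModelLoss` objective. The sketch below is a self-contained stand-in for the lbfgs branch (the loss/gradient helper and the data are illustrative; only the `minimize` options mirror the diff).

import numpy as np
from scipy.optimize import minimize

def loss_grad(w, X, y, sw, l2_reg_strength):
    """Averaged binary log loss plus L2 penalty, returning (loss, gradient)."""
    sw_sum = sw.sum()
    raw = X @ w
    loss = (sw @ (np.logaddexp(0.0, raw) - y * raw)) / sw_sum
    loss += 0.5 * l2_reg_strength * w @ w
    grad = X.T @ (sw * (1.0 / (1.0 + np.exp(-raw)) - y)) / sw_sum + l2_reg_strength * w
    return loss, grad

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X @ rng.normal(size=5) + rng.normal(size=200) > 0).astype(float)
sw = np.ones(X.shape[0])
C, tol = 1.0, 1e-8
l2_reg_strength = 1.0 / (C * sw.sum())  # penalty rescaled exactly as in the diff

opt_res = minimize(
    loss_grad,
    np.zeros(X.shape[1]),
    method="L-BFGS-B",
    jac=True,
    args=(X, y, sw, l2_reg_strength),
    options={
        "maxiter": 100,
        "maxls": 50,  # default is 20
        "iprint": -1,
        "gtol": tol,
        "ftol": 64 * np.finfo(float).eps,
    },
)
print(opt_res.x, opt_res.nit)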

sklearn/linear_model/tests/test_common.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@
         ),
         marks=pytest.mark.xfail(reason="Missing importance sampling scheme"),
     ),
-    LogisticRegressionCV(),
+    LogisticRegressionCV(tol=1e-6),
     MultiTaskElasticNet(),
     MultiTaskElasticNetCV(),
     MultiTaskLasso(),
