
Commit c08b433

ENH multiclass/multinomial newton cholesky for LogisticRegression (scikit-learn#28840)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 18dc863 commit c08b433

File tree

7 files changed: 622 additions and 161 deletions

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+- The `solver="newton-cholesky"` in
+  :class:`linear_model.LogisticRegression` and
+  :class:`linear_model.LogisticRegressionCV` is extended to support the full
+  multinomial loss in a multiclass setting.
+  By :user:`Christian Lorentzen <lorentzenchr>`.

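For context, a minimal usage sketch of the feature described in the changelog entry above; the dataset and settings are illustrative and not part of the commit. With this change, a multiclass problem is fit with the full multinomial loss rather than one-vs-rest when solver="newton-cholesky" is selected:

# Illustrative sketch only (assumes a scikit-learn build that contains this commit).
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)  # 3 classes -> full multinomial loss is used
clf = LogisticRegression(solver="newton-cholesky").fit(X, y)
print(clf.coef_.shape)  # (3, 4): one coefficient row per class
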
sklearn/_loss/tests/test_loss.py

Lines changed: 13 additions & 5 deletions
@@ -419,52 +419,60 @@ def test_loss_dtype(
     if sample_weight is not None:
         sample_weight = create_memmap_backed_data(sample_weight)
 
-    loss.loss(
+    l = loss.loss(
         y_true=y_true,
         raw_prediction=raw_prediction,
         sample_weight=sample_weight,
         loss_out=out1,
         n_threads=n_threads,
     )
-    loss.gradient(
+    assert l is out1 if out1 is not None else True
+    g = loss.gradient(
         y_true=y_true,
         raw_prediction=raw_prediction,
         sample_weight=sample_weight,
         gradient_out=out2,
         n_threads=n_threads,
     )
-    loss.loss_gradient(
+    assert g is out2 if out2 is not None else True
+    l, g = loss.loss_gradient(
         y_true=y_true,
         raw_prediction=raw_prediction,
         sample_weight=sample_weight,
         loss_out=out1,
         gradient_out=out2,
         n_threads=n_threads,
     )
+    assert l is out1 if out1 is not None else True
+    assert g is out2 if out2 is not None else True
     if out1 is not None and loss.is_multiclass:
         out1 = np.empty_like(raw_prediction, dtype=dtype_out)
-    loss.gradient_hessian(
+    g, h = loss.gradient_hessian(
         y_true=y_true,
         raw_prediction=raw_prediction,
         sample_weight=sample_weight,
         gradient_out=out1,
         hessian_out=out2,
         n_threads=n_threads,
     )
+    assert g is out1 if out1 is not None else True
+    assert h is out2 if out2 is not None else True
     loss(y_true=y_true, raw_prediction=raw_prediction, sample_weight=sample_weight)
     loss.fit_intercept_only(y_true=y_true, sample_weight=sample_weight)
     loss.constant_to_optimal_zero(y_true=y_true, sample_weight=sample_weight)
     if hasattr(loss, "predict_proba"):
         loss.predict_proba(raw_prediction=raw_prediction)
     if hasattr(loss, "gradient_proba"):
-        loss.gradient_proba(
+        g, p = loss.gradient_proba(
             y_true=y_true,
             raw_prediction=raw_prediction,
             sample_weight=sample_weight,
             gradient_out=out1,
             proba_out=out2,
             n_threads=n_threads,
         )
+        assert g is out1 if out1 is not None else True
+        assert p is out2 if out2 is not None else True
 
 
 @pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
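The new assertions pin down the out-parameter contract exercised by this test: when a preallocated array is passed (loss_out, gradient_out, ...), the method writes into it and returns that very array. A small sketch of that contract, assuming the private sklearn._loss API as it looks at the time of this commit (private, may change):

# Sketch of the out-parameter contract; sklearn._loss is a private module.
import numpy as np
from sklearn._loss.loss import HalfSquaredError

loss = HalfSquaredError()
y_true = np.array([0.0, 1.0, 2.0])
raw_prediction = np.array([0.5, 1.0, 1.5])
out = np.empty_like(y_true)

result = loss.loss(y_true=y_true, raw_prediction=raw_prediction, loss_out=out)
assert result is out  # the provided buffer is filled and returned, no new allocation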

sklearn/linear_model/_glm/_newton_solver.py

Lines changed: 104 additions & 12 deletions
@@ -298,6 +298,11 @@ def line_search(self, X, y, sample_weight):
             return
 
         self.raw_prediction = raw
+        if is_verbose:
+            print(
+                f"  line search successful after {i+1} iterations with "
+                f"loss={self.loss_value}."
+            )
 
     def check_convergence(self, X, y, sample_weight):
         """Check for convergence.
@@ -310,24 +315,27 @@ def check_convergence(self, X, y, sample_weight):
         # convergence criterion because even a large step could have brought us close
         # to the true minimum.
         # coef_step = self.coef - self.coef_old
-        # check = np.max(np.abs(coef_step) / np.maximum(1, np.abs(self.coef_old)))
+        # change = np.max(np.abs(coef_step) / np.maximum(1, np.abs(self.coef_old)))
+        # check = change <= tol
 
         # 1. Criterion: maximum |gradient| <= tol
         # The gradient was already updated in line_search()
-        check = np.max(np.abs(self.gradient))
+        g_max_abs = np.max(np.abs(self.gradient))
+        check = g_max_abs <= self.tol
         if self.verbose:
-            print(f"  1. max |gradient| {check} <= {self.tol}")
-        if check > self.tol:
+            print(f"  1. max |gradient| {g_max_abs} <= {self.tol} {check}")
+        if not check:
             return
 
         # 2. Criterion: For Newton decrement d, check 1/2 * d^2 <= tol
         #    d = sqrt(grad @ hessian^-1 @ grad)
         #      = sqrt(coef_newton @ hessian @ coef_newton)
         #    See Boyd, Vanderberghe (2009) "Convex Optimization" Chapter 9.5.1.
         d2 = self.coef_newton @ self.hessian @ self.coef_newton
+        check = 0.5 * d2 <= self.tol
         if self.verbose:
-            print(f"  2. Newton decrement {0.5 * d2} <= {self.tol}")
-        if 0.5 * d2 > self.tol:
+            print(f"  2. Newton decrement {0.5 * d2} <= {self.tol} {check}")
+        if not check:
             return
 
         if self.verbose:
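Both convergence criteria are now explicit booleans. The Newton decrement identity used in the hunk above follows from coef_newton = -H^{-1} g, so coef_newton @ H @ coef_newton equals gradient @ H^{-1} @ gradient. A numpy sketch with made-up numbers (not from the commit):

# Toy illustration of the two convergence checks (made-up gradient/Hessian).
import numpy as np

tol = 1e-4
gradient = np.array([2e-5, -3e-5])
hessian = np.array([[2.0, 0.3], [0.3, 1.0]])

check1 = np.max(np.abs(gradient)) <= tol           # criterion 1: max |gradient|

coef_newton = np.linalg.solve(hessian, -gradient)  # Newton step
d2 = coef_newton @ hessian @ coef_newton           # = gradient @ H^{-1} @ gradient
check2 = 0.5 * d2 <= tol                           # criterion 2: Newton decrement

print(check1 and check2)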
@@ -442,11 +450,23 @@ class NewtonCholeskySolver(NewtonSolver):
 
     def setup(self, X, y, sample_weight):
         super().setup(X=X, y=y, sample_weight=sample_weight)
-        n_dof = X.shape[1]
-        if self.linear_loss.fit_intercept:
-            n_dof += 1
+        if self.linear_loss.base_loss.is_multiclass:
+            # Easier with ravelled arrays, e.g., for scipy.linalg.solve.
+            # As with LinearModelLoss, we always are contiguous in n_classes.
+            self.coef = self.coef.ravel(order="F")
+        # Note that the computation of gradient in LinearModelLoss follows the shape of
+        # coef.
         self.gradient = np.empty_like(self.coef)
-        self.hessian = np.empty_like(self.coef, shape=(n_dof, n_dof))
+        # But the hessian is always 2d.
+        n = self.coef.size
+        self.hessian = np.empty_like(self.coef, shape=(n, n))
+        # To help case distinctions.
+        self.is_multinomial_with_intercept = (
+            self.linear_loss.base_loss.is_multiclass and self.linear_loss.fit_intercept
+        )
+        self.is_multinomial_no_penalty = (
+            self.linear_loss.base_loss.is_multiclass and self.l2_reg_strength == 0
+        )
 
     def update_gradient_hessian(self, X, y, sample_weight):
         _, _, self.hessian_warning = self.linear_loss.gradient_hessian(
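The ravel(order="F") in setup() is what makes the flattened coefficient vector contiguous in classes: for coef of shape (n_classes, n_dof), the entries of all classes for one feature end up adjacent. A toy illustration (arbitrary values, not from the commit):

# Toy illustration of the class-contiguous layout produced by ravel(order="F").
import numpy as np

n_classes, n_dof = 3, 2
coef = np.arange(n_classes * n_dof).reshape(n_classes, n_dof)
# coef = [[0, 1],
#         [2, 3],
#         [4, 5]]
flat = coef.ravel(order="F")
print(flat)  # [0 2 4 1 3 5]: classes vary fastest, i.e. contiguous per feature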
@@ -479,12 +499,70 @@ def inner_solve(self, X, y, sample_weight):
             self.use_fallback_lbfgs_solve = True
             return
 
+        # Note: The following case distinction could also be shifted to the
+        # implementation of HalfMultinomialLoss instead of here within the solver.
+        if self.is_multinomial_no_penalty:
+            # The multinomial loss is overparametrized for each unpenalized feature, so
+            # at least the intercepts. This can be seen by noting that predicted
+            # probabilities are invariant under shifting all coefficients of a single
+            # feature j for all classes by the same amount c:
+            #     coef[k, :] -> coef[k, :] + c  =>  proba stays the same
+            # where we have assumned coef.shape = (n_classes, n_features).
+            # Therefore, also the loss (-log-likelihood), gradient and hessian stay the
+            # same, see
+            # Noah Simon and Jerome Friedman and Trevor Hastie. (2013) "A Blockwise
+            # Descent Algorithm for Group-penalized Multiresponse and Multinomial
+            # Regression". https://doi.org/10.48550/arXiv.1311.6529
+            #
+            # We choose the standard approach and set all the coefficients of the last
+            # class to zero, for all features including the intercept.
+            n_classes = self.linear_loss.base_loss.n_classes
+            n_dof = self.coef.size // n_classes  # degree of freedom per class
+            n = self.coef.size - n_dof  # effective size
+            self.coef[n_classes - 1 :: n_classes] = 0
+            self.gradient[n_classes - 1 :: n_classes] = 0
+            self.hessian[n_classes - 1 :: n_classes, :] = 0
+            self.hessian[:, n_classes - 1 :: n_classes] = 0
+            # We also need the reduced variants of gradient and hessian where the
+            # entries set to zero are removed. For 2 features and 3 classes with
+            # arbitrary values, "x" means removed:
+            #     gradient = [0, 1, x, 3, 4, x]
+            #
+            #     hessian = [0,  1, x,  3,  4, x]
+            #               [1,  7, x,  9, 10, x]
+            #               [x,  x, x,  x,  x, x]
+            #               [3,  9, x, 21, 22, x]
+            #               [4, 10, x, 22, 28, x]
+            #               [x,  x, x,  x,  x, x]
+            # The following slicing triggers copies of gradient and hessian.
+            gradient = self.gradient.reshape(-1, n_classes)[:, :-1].flatten()
+            hessian = self.hessian.reshape(n_dof, n_classes, n_dof, n_classes)[
+                :, :-1, :, :-1
+            ].reshape(n, n)
+        elif self.is_multinomial_with_intercept:
+            # Here, only intercepts are unpenalized. We again choose the last class and
+            # set its intercept to zero.
+            self.coef[-1] = 0
+            self.gradient[-1] = 0
+            self.hessian[-1, :] = 0
+            self.hessian[:, -1] = 0
+            gradient, hessian = self.gradient[:-1], self.hessian[:-1, :-1]
+        else:
+            gradient, hessian = self.gradient, self.hessian
+
         try:
             with warnings.catch_warnings():
                 warnings.simplefilter("error", scipy.linalg.LinAlgWarning)
                 self.coef_newton = scipy.linalg.solve(
-                    self.hessian, -self.gradient, check_finite=False, assume_a="sym"
+                    hessian, -gradient, check_finite=False, assume_a="sym"
                 )
+            if self.is_multinomial_no_penalty:
+                self.coef_newton = np.c_[
+                    self.coef_newton.reshape(n_dof, n_classes - 1), np.zeros(n_dof)
+                ].reshape(-1)
+                assert self.coef_newton.flags.f_contiguous
+            elif self.is_multinomial_with_intercept:
+                self.coef_newton = np.r_[self.coef_newton, 0]
             self.gradient_times_newton = self.gradient @ self.coef_newton
             if self.gradient_times_newton > 0:
                 if self.verbose:
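The 4D reshape in the hunk above is the mechanism that removes every row and column belonging to the last (zeroed) class, matching the "x means removed" picture in the comments. A small check with 2 features and 3 classes (arbitrary symmetric matrix, not from the commit):

# Toy check of the slicing that drops the last class from the Hessian.
import numpy as np

n_classes, n_dof = 3, 2
n = n_dof * (n_classes - 1)  # effective size after removing one class

rng = np.random.default_rng(0)
a = rng.normal(size=(n_dof * n_classes, n_dof * n_classes))
hessian = a + a.T  # any symmetric matrix with class-contiguous ordering

hessian[n_classes - 1 :: n_classes, :] = 0  # zero out the rows of the last class
hessian[:, n_classes - 1 :: n_classes] = 0  # and its columns
reduced = hessian.reshape(n_dof, n_classes, n_dof, n_classes)[
    :, :-1, :, :-1
].reshape(n, n)
print(reduced.shape)  # (4, 4): only the first n_classes - 1 classes remain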
@@ -498,7 +576,7 @@ def inner_solve(self, X, y, sample_weight):
             warnings.warn(
                 f"The inner solver of {self.__class__.__name__} stumbled upon a "
                 "singular or very ill-conditioned Hessian matrix at iteration "
-                f"#{self.iteration}. It will now resort to lbfgs instead.\n"
+                f"{self.iteration}. It will now resort to lbfgs instead.\n"
                 "Further options are to use another solver or to avoid such situation "
                 "in the first place. Possible remedies are removing collinear features"
                 " of X or increasing the penalization strengths.\n"
@@ -522,3 +600,17 @@ def inner_solve(self, X, y, sample_weight):
             )
             self.use_fallback_lbfgs_solve = True
             return
+
+    def finalize(self, X, y, sample_weight):
+        if self.is_multinomial_no_penalty:
+            # Our convention is usually the symmetric parametrization where
+            # sum(coef[classes, features], axis=0) = 0.
+            # We convert now to this convention. Note that it does not change
+            # the predicted probabilities.
+            n_classes = self.linear_loss.base_loss.n_classes
+            self.coef = self.coef.reshape(n_classes, -1, order="F")
+            self.coef -= np.mean(self.coef, axis=0)
+        elif self.is_multinomial_with_intercept:
+            # Only the intercept needs an update to the symmetric parametrization.
+            n_classes = self.linear_loss.base_loss.n_classes
+            self.coef[-n_classes:] -= np.mean(self.coef[-n_classes:])
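finalize() converts the solution back to the symmetric parametrization (the coefficients of each feature sum to zero over classes). Subtracting the per-feature mean over classes does not change the predicted probabilities, which is easy to verify numerically (toy values, not from the commit):

# Toy check: shifting per-feature coefficients equally across classes leaves
# the softmax probabilities unchanged.
import numpy as np
from scipy.special import softmax

rng = np.random.default_rng(0)
n_classes, n_features, n_samples = 3, 2, 5
coef = rng.normal(size=(n_classes, n_features))
X = rng.normal(size=(n_samples, n_features))

coef_sym = coef - coef.mean(axis=0)  # symmetric: each column sums to 0
proba = softmax(X @ coef.T, axis=1)
proba_sym = softmax(X @ coef_sym.T, axis=1)
np.testing.assert_allclose(proba, proba_sym)  # identical predicted probabilities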
