
Commit c66d595

MNT cleaner Cython coordinate descent in _cd_fast.pyx (scikit-learn#31372)
1 parent 44b1f4d commit c66d595

File tree: 1 file changed, +76 −68 lines


sklearn/linear_model/_cd_fast.pyx

Lines changed: 76 additions & 68 deletions
@@ -83,6 +83,21 @@ cdef floating diff_abs_max(int n, const floating* a, floating* b) noexcept nogil
     return m
 
 
+message_conv = (
+    "Objective did not converge. You might want to increase "
+    "the number of iterations, check the scale of the "
+    "features or consider increasing regularisation."
+)
+
+
+message_ridge = (
+    "Linear regression models with a zero l1 penalization "
+    "strength are more efficiently fitted using one of the "
+    "solvers implemented in "
+    "sklearn.linear_model.Ridge/RidgeCV instead."
+)
+
+
 def enet_coordinate_descent(
     floating[::1] w,
     floating alpha,
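These two messages become module-level constants so every solver variant in the file can share them; the run-specific duality gap and tolerance are appended with an f-string at warn time. A minimal standalone Python sketch of the pattern (the gap/tol values here are illustrative):

import warnings

from sklearn.exceptions import ConvergenceWarning

message_conv = (
    "Objective did not converge. You might want to increase "
    "the number of iterations, check the scale of the "
    "features or consider increasing regularisation."
)

gap, tol = 1.2e-03, 1.0e-04  # illustrative values
# Shared prefix + run-specific details, as the solvers compose it below.
warnings.warn(
    message_conv + f" Duality gap: {gap:.3e}, tolerance: {tol:.3e}",
    ConvergenceWarning,
)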
@@ -141,7 +156,7 @@ def enet_coordinate_descent(
     cdef floating R_norm2
     cdef floating w_norm2
     cdef floating l1_norm
-    cdef floating const
+    cdef floating const_
     cdef floating A_norm2
     cdef unsigned int ii
     cdef unsigned int n_iter = 0
@@ -227,19 +242,18 @@
                 w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
 
                 if (dual_norm_XtA > alpha):
-                    const = alpha / dual_norm_XtA
-                    A_norm2 = R_norm2 * (const ** 2)
+                    const_ = alpha / dual_norm_XtA
+                    A_norm2 = R_norm2 * (const_ ** 2)
                     gap = 0.5 * (R_norm2 + A_norm2)
                 else:
-                    const = 1.0
+                    const_ = 1.0
                     gap = R_norm2
 
                 l1_norm = _asum(n_features, &w[0], 1)
 
-                # np.dot(R.T, y)
                 gap += (alpha * l1_norm
-                        - const * _dot(n_samples, &R[0], 1, &y[0], 1)
-                        + 0.5 * beta * (1 + const ** 2) * (w_norm2))
+                        - const_ * _dot(n_samples, &R[0], 1, &y[0], 1)  # np.dot(R.T, y)
+                        + 0.5 * beta * (1 + const_ ** 2) * (w_norm2))
 
                 if gap < tol:
                     # return if we reached desired tolerance
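For reference, the quantity assembled here is the elastic net duality gap used as the final stopping criterion. A rough NumPy transcription of these Cython lines for the dense case (R and XtA are computed earlier in the solver loop; names mirror the diff, and the function itself is illustrative, not part of the source):

import numpy as np

def dual_gap(X, y, w, alpha, beta):
    # Illustrative dense elastic net duality gap, mirroring _cd_fast.pyx.
    R = y - X @ w                       # residual
    XtA = X.T @ R - beta * w
    dual_norm_XtA = np.max(np.abs(XtA))
    R_norm2 = R @ R
    w_norm2 = w @ w
    if dual_norm_XtA > alpha:
        const_ = alpha / dual_norm_XtA  # rescale to a dual-feasible point
        A_norm2 = R_norm2 * const_ ** 2
        gap = 0.5 * (R_norm2 + A_norm2)
    else:
        const_ = 1.0
        gap = R_norm2
    l1_norm = np.sum(np.abs(w))
    gap += (alpha * l1_norm
            - const_ * (R @ y)          # np.dot(R.T, y)
            + 0.5 * beta * (1 + const_ ** 2) * w_norm2)
    return gap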
@@ -249,18 +263,11 @@
             # for/else, runs if for doesn't end with a `break`
             with gil:
                 message = (
-                    "Objective did not converge. You might want to increase "
-                    "the number of iterations, check the scale of the "
-                    "features or consider increasing regularisation. "
-                    f"Duality gap: {gap:.3e}, tolerance: {tol:.3e}"
+                    message_conv +
+                    f" Duality gap: {gap:.3e}, tolerance: {tol:.3e}"
                 )
                 if alpha < np.finfo(np.float64).eps:
-                    message += (
-                        " Linear regression models with null weight for the "
-                        "l1 regularization term are more efficiently fitted "
-                        "using one of the solvers implemented in "
-                        "sklearn.linear_model.Ridge/RidgeCV instead."
-                    )
+                    message += "\n" + message_ridge
                 warnings.warn(message, ConvergenceWarning)
 
     return np.asarray(w), gap, tol, n_iter + 1
@@ -313,53 +320,50 @@ def sparse_enet_coordinate_descent(
     # that every calculation results as if we had rescaled y and X (and therefore also
     # X_mean) by sqrt(sample_weight) without actually calculating the square root.
     # We work with:
-    # yw = sample_weight
+    # yw = sample_weight * y
     # R = sample_weight * residual
     # norm_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0)
 
+    if floating is float:
+        dtype = np.float32
+    else:
+        dtype = np.float64
+
     # get the data information into easy vars
     cdef unsigned int n_samples = y.shape[0]
     cdef unsigned int n_features = w.shape[0]
 
     # compute norms of the columns of X
-    cdef unsigned int ii
-    cdef floating[:] norm_cols_X
-
-    cdef unsigned int startptr = X_indptr[0]
-    cdef unsigned int endptr
+    cdef floating[:] norm_cols_X = np.zeros(n_features, dtype=dtype)
 
     # initial value of the residuals
     # R = y - Zw, weighted version R = sample_weight * (y - Zw)
     cdef floating[::1] R
-    cdef floating[::1] XtA
+    cdef floating[::1] XtA = np.empty(n_features, dtype=dtype)
     cdef const floating[::1] yw
 
-    if floating is float:
-        dtype = np.float32
-    else:
-        dtype = np.float64
-
-    norm_cols_X = np.zeros(n_features, dtype=dtype)
-    XtA = np.zeros(n_features, dtype=dtype)
-
     cdef floating tmp
     cdef floating w_ii
     cdef floating d_w_max
     cdef floating w_max
     cdef floating d_w_ii
+    cdef floating gap = tol + 1.0
+    cdef floating d_w_tol = tol
+    cdef floating dual_norm_XtA
     cdef floating X_mean_ii
     cdef floating R_sum = 0.0
     cdef floating R_norm2
     cdef floating w_norm2
-    cdef floating A_norm2
     cdef floating l1_norm
+    cdef floating const_
+    cdef floating A_norm2
     cdef floating normalize_sum
-    cdef floating gap = tol + 1.0
-    cdef floating d_w_tol = tol
-    cdef floating dual_norm_XtA
+    cdef unsigned int ii
     cdef unsigned int jj
     cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
+    cdef unsigned int startptr = X_indptr[0]
+    cdef unsigned int endptr
     cdef uint32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
     cdef uint32_t* rand_r_state = &rand_r_state_seed
     cdef bint center = False
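Hoisting the `floating is float` branch to the top of the function lets `norm_cols_X` and `XtA` be initialized inline with their declarations; a branch on a fused type is resolved at Cython compile time, so each specialization carries no runtime check. Switching `XtA` from `np.zeros` to `np.empty` is safe as long as the buffer is fully overwritten before being read, which appears to be the intent here. A pure-Python analogue of the dtype dispatch (the helper name is hypothetical):

import numpy as np

def make_buffers(w, n_features):
    # Runtime analogue of the compile-time `floating is float` branch:
    # work in the precision of the coefficient array.
    dtype = np.float32 if w.dtype == np.float32 else np.float64
    norm_cols_X = np.zeros(n_features, dtype=dtype)
    XtA = np.empty(n_features, dtype=dtype)  # filled before first use
    return norm_cols_X, XtA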
@@ -380,6 +384,7 @@
                 center = True
                 break
 
+        # R = y - np.dot(X, w)
         for ii in range(n_features):
             X_mean_ii = X_mean[ii]
             endptr = X_indptr[ii + 1]
@@ -396,6 +401,7 @@
                 for jj in range(n_samples):
                     R[jj] += X_mean_ii * w_ii
             else:
+                # R = sw * (y - np.dot(X, w))
                 for jj in range(startptr, endptr):
                     tmp = sample_weight[X_indices[jj]]
                     # second term will be subtracted by loop over range(n_samples)
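The two added comments document what the initialization loops compute. In NumPy/SciPy terms the same residual can be written as below (a sketch only; the Cython loops work column-by-column on the CSC arrays and never materialize a dense product, and the helper name is hypothetical):

import numpy as np

def initial_residual(X, y, w, X_mean=None, sample_weight=None):
    # X is a scipy.sparse CSC matrix; X_mean holds column means used for
    # implicit centering of the sparse data.
    Xw = np.asarray(X @ w).ravel()
    if X_mean is not None:
        Xw -= X_mean @ w              # subtract the centering contribution
    R = y - Xw                        # R = y - np.dot(X, w)
    if sample_weight is not None:
        R = sample_weight * R         # R = sw * (y - np.dot(X, w))
    return R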
@@ -526,21 +532,18 @@
                 # w_norm2 = np.dot(w, w)
                 w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
                 if (dual_norm_XtA > alpha):
-                    const = alpha / dual_norm_XtA
-                    A_norm2 = R_norm2 * const**2
+                    const_ = alpha / dual_norm_XtA
+                    A_norm2 = R_norm2 * const_**2
                     gap = 0.5 * (R_norm2 + A_norm2)
                 else:
-                    const = 1.0
+                    const_ = 1.0
                     gap = R_norm2
 
                 l1_norm = _asum(n_features, &w[0], 1)
 
-                gap += (alpha * l1_norm - const * _dot(
-                    n_samples,
-                    &R[0], 1,
-                    &y[0], 1
-                    )
-                    + 0.5 * beta * (1 + const ** 2) * w_norm2)
+                gap += (alpha * l1_norm
+                        - const_ * _dot(n_samples, &R[0], 1, &y[0], 1)  # np.dot(R.T, y)
+                        + 0.5 * beta * (1 + const_ ** 2) * w_norm2)
 
                 if gap < tol:
                     # return if we reached desired tolerance
@@ -549,10 +552,13 @@
         else:
             # for/else, runs if for doesn't end with a `break`
             with gil:
-                warnings.warn("Objective did not converge. You might want to "
-                              "increase the number of iterations. Duality "
-                              "gap: {}, tolerance: {}".format(gap, tol),
-                              ConvergenceWarning)
+                message = (
+                    message_conv +
+                    f" Duality gap: {gap:.3e}, tolerance: {tol:.3e}"
+                )
+                if alpha < np.finfo(np.float64).eps:
+                    message += "\n" + message_ridge
+                warnings.warn(message, ConvergenceWarning)
 
     return np.asarray(w), gap, tol, n_iter + 1
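With the shared constants, the sparse path now also appends the Ridge/RidgeCV hint when the l1 strength is numerically zero, instead of silently omitting it. A small usage sketch that can provoke the convergence warning (illustrative data and parameters; whether it fires depends on the random data):

import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 100))
y = rng.standard_normal(50)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Lasso(alpha=1e-4, max_iter=2).fit(X, y)  # far too few iterations

conv = [w for w in caught if issubclass(w.category, ConvergenceWarning)]
if conv:
    print(conv[0].message)  # "Objective did not converge. ... Duality gap: ..."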

@@ -702,19 +708,19 @@ def enet_coordinate_descent_gram(
                 w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
 
                 if (dual_norm_XtA > alpha):
-                    const = alpha / dual_norm_XtA
-                    A_norm2 = R_norm2 * (const ** 2)
+                    const_ = alpha / dual_norm_XtA
+                    A_norm2 = R_norm2 * (const_ ** 2)
                     gap = 0.5 * (R_norm2 + A_norm2)
                 else:
-                    const = 1.0
+                    const_ = 1.0
                     gap = R_norm2
 
                 # The call to asum is equivalent to the L1 norm of w
                 gap += (
                     alpha * _asum(n_features, &w[0], 1)
-                    - const * y_norm2
-                    + const * q_dot_w
-                    + 0.5 * beta * (1 + const ** 2) * w_norm2
+                    - const_ * y_norm2
+                    + const_ * q_dot_w
+                    + 0.5 * beta * (1 + const_ ** 2) * w_norm2
                 )
 
                 if gap < tol:
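Unlike the other variants, the Gram solver never holds the data-space residual r = y - Xw, so it cannot take np.dot(R.T, y) directly; it relies instead on the identity r.T y = ||y||^2 - q.T w with q = X.T y, which is why y_norm2 and q_dot_w appear here in place of the dot product used elsewhere. A quick numeric check of that identity:

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))
y = rng.standard_normal(20)
w = rng.standard_normal(5)

q = X.T @ y        # precomputed by the Gram variant
r = y - X @ w      # data-space residual, never formed by the Gram variant
assert np.isclose(r @ y, y @ y - q @ w)  # r.T y == ||y||^2 - q.T w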
@@ -724,10 +730,11 @@
         else:
             # for/else, runs if for doesn't end with a `break`
             with gil:
-                warnings.warn("Objective did not converge. You might want to "
-                              "increase the number of iterations. Duality "
-                              "gap: {}, tolerance: {}".format(gap, tol),
-                              ConvergenceWarning)
+                message = (
+                    message_conv +
+                    f" Duality gap: {gap:.3e}, tolerance: {tol:.3e}"
+                )
+                warnings.warn(message, ConvergenceWarning)
 
     return np.asarray(w), gap, tol, n_iter + 1

@@ -921,11 +928,11 @@ def enet_coordinate_descent_multi_task(
                 R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)
                 w_norm = _nrm2(n_features * n_tasks, &W[0, 0], 1)
                 if (dual_norm_XtA > l1_reg):
-                    const = l1_reg / dual_norm_XtA
-                    A_norm = R_norm * const
+                    const_ = l1_reg / dual_norm_XtA
+                    A_norm = R_norm * const_
                     gap = 0.5 * (R_norm ** 2 + A_norm ** 2)
                 else:
-                    const = 1.0
+                    const_ = 1.0
                     gap = R_norm ** 2
 
                 # ry_sum = np.sum(R * y)
938945

939946
gap += (
940947
l1_reg * l21_norm
941-
- const * ry_sum
942-
+ 0.5 * l2_reg * (1 + const ** 2) * (w_norm ** 2)
948+
- const_ * ry_sum
949+
+ 0.5 * l2_reg * (1 + const_ ** 2) * (w_norm ** 2)
943950
)
944951

945952
if gap <= tol:
@@ -948,9 +955,10 @@ def enet_coordinate_descent_multi_task(
948955
else:
949956
# for/else, runs if for doesn't end with a `break`
950957
with gil:
951-
warnings.warn("Objective did not converge. You might want to "
952-
"increase the number of iterations. Duality "
953-
"gap: {}, tolerance: {}".format(gap, tol),
954-
ConvergenceWarning)
958+
message = (
959+
message_conv +
960+
f" Duality gap: {gap:.3e}, tolerance: {tol:.3e}"
961+
)
962+
warnings.warn(message, ConvergenceWarning)
955963

956964
return np.asarray(W), gap, tol, n_iter + 1
