Skip to content

Commit 845e702

Browse files
yenchenlin authored and ogrisel committed
[MRG+2] Make enet_coordinate_descent_gram support fused types (scikit-learn#7218)
* Make cd_gram support fused types * Add test
1 parent 5b20d48 commit 845e702

File tree

2 files changed

+64
-33
lines changed

2 files changed

+64
-33
lines changed

sklearn/linear_model/cd_fast.pyx

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -528,11 +528,11 @@ def sparse_enet_coordinate_descent(floating [:] w,
528528
@cython.boundscheck(False)
529529
@cython.wraparound(False)
530530
@cython.cdivision(True)
531-
def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
532-
np.ndarray[double, ndim=2, mode='c'] Q,
533-
np.ndarray[double, ndim=1, mode='c'] q,
534-
np.ndarray[double, ndim=1] y,
535-
int max_iter, double tol, object rng,
531+
def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta,
532+
np.ndarray[floating, ndim=2, mode='c'] Q,
533+
np.ndarray[floating, ndim=1, mode='c'] q,
534+
np.ndarray[floating, ndim=1] y,
535+
int max_iter, floating tol, object rng,
536536
bint random=0, bint positive=0):
537537
"""Cython version of the coordinate descent algorithm
538538
for Elastic-Net regression
@@ -546,34 +546,52 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
546546
q = X^T y
547547
"""
548548

549+
# fused types version of BLAS functions
550+
cdef DOT dot
551+
cdef AXPY axpy
552+
cdef ASUM asum
553+
554+
if floating is float:
555+
dtype = np.float32
556+
dot = sdot
557+
axpy = saxpy
558+
asum = sasum
559+
else:
560+
dtype = np.float64
561+
dot = ddot
562+
axpy = daxpy
563+
asum = dasum
564+
549565
# get the data information into easy vars
550566
cdef unsigned int n_samples = y.shape[0]
551567
cdef unsigned int n_features = Q.shape[0]
552568

553569
# initial value "Q w" which will be kept of up to date in the iterations
554-
cdef double[:] H = np.dot(Q, w)
570+
cdef floating[:] H = np.dot(Q, w)
555571

556-
cdef double[:] XtA = np.zeros(n_features)
557-
cdef double tmp
558-
cdef double w_ii
559-
cdef double d_w_max
560-
cdef double w_max
561-
cdef double d_w_ii
562-
cdef double gap = tol + 1.0
563-
cdef double d_w_tol = tol
564-
cdef double dual_norm_XtA
572+
cdef floating[:] XtA = np.zeros(n_features, dtype=dtype)
573+
cdef floating tmp
574+
cdef floating w_ii
575+
cdef floating d_w_max
576+
cdef floating w_max
577+
cdef floating d_w_ii
578+
cdef floating q_dot_w
579+
cdef floating w_norm2
580+
cdef floating gap = tol + 1.0
581+
cdef floating d_w_tol = tol
582+
cdef floating dual_norm_XtA
565583
cdef unsigned int ii
566584
cdef unsigned int n_iter = 0
567585
cdef unsigned int f_iter
568586
cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
569587
cdef UINT32_t* rand_r_state = &rand_r_state_seed
570588

571-
cdef double y_norm2 = np.dot(y, y)
572-
cdef double* w_ptr = <double*>&w[0]
573-
cdef double* Q_ptr = &Q[0, 0]
574-
cdef double* q_ptr = <double*>q.data
575-
cdef double* H_ptr = &H[0]
576-
cdef double* XtA_ptr = &XtA[0]
589+
cdef floating y_norm2 = np.dot(y, y)
590+
cdef floating* w_ptr = <floating*>&w[0]
591+
cdef floating* Q_ptr = &Q[0, 0]
592+
cdef floating* q_ptr = <floating*>q.data
593+
cdef floating* H_ptr = &H[0]
594+
cdef floating* XtA_ptr = &XtA[0]
577595
tol = tol * y_norm2
578596

579597
if alpha == 0:
@@ -597,8 +615,8 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
597615

598616
if w_ii != 0.0:
599617
# H -= w_ii * Q[ii]
600-
daxpy(n_features, -w_ii, Q_ptr + ii * n_features, 1,
601-
H_ptr, 1)
618+
axpy(n_features, -w_ii, Q_ptr + ii * n_features, 1,
619+
H_ptr, 1)
602620

603621
tmp = q[ii] - H[ii]
604622

@@ -610,8 +628,8 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
610628

611629
if w[ii] != 0.0:
612630
# H += w[ii] * Q[ii] # Update H = X.T X w
613-
daxpy(n_features, w[ii], Q_ptr + ii * n_features, 1,
614-
H_ptr, 1)
631+
axpy(n_features, w[ii], Q_ptr + ii * n_features, 1,
632+
H_ptr, 1)
615633

616634
# update the maximum absolute coefficient update
617635
d_w_ii = fabs(w[ii] - w_ii)
@@ -627,7 +645,7 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
627645
# criterion
628646

629647
# q_dot_w = np.dot(w, q)
630-
q_dot_w = ddot(n_features, w_ptr, 1, q_ptr, 1)
648+
q_dot_w = dot(n_features, w_ptr, 1, q_ptr, 1)
631649

632650
for ii in range(n_features):
633651
XtA[ii] = q[ii] - H[ii] - beta * w[ii]
@@ -643,7 +661,7 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
643661
R_norm2 = y_norm2 + tmp - 2.0 * q_dot_w
644662

645663
# w_norm2 = np.dot(w, w)
646-
w_norm2 = ddot(n_features, &w[0], 1, &w[0], 1)
664+
w_norm2 = dot(n_features, &w[0], 1, &w[0], 1)
647665

648666
if (dual_norm_XtA > alpha):
649667
const = alpha / dual_norm_XtA
@@ -654,7 +672,7 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
654672
gap = R_norm2
655673

656674
# The call to dasum is equivalent to the L1 norm of w
657-
gap += (alpha * dasum(n_features, &w[0], 1) -
675+
gap += (alpha * asum(n_features, &w[0], 1) -
658676
const * y_norm2 + const * q_dot_w +
659677
0.5 * beta * (1 + const ** 2) * w_norm2)
660678

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,11 @@ def test_enet_float_precision():
682682
for fit_intercept in [True, False]:
683683
coef = {}
684684
intercept = {}
685-
clf = ElasticNet(alpha=0.5, max_iter=100, precompute=False,
686-
fit_intercept=fit_intercept, normalize=normalize)
687685
for dtype in [np.float64, np.float32]:
686+
clf = ElasticNet(alpha=0.5, max_iter=100, precompute=False,
687+
fit_intercept=fit_intercept,
688+
normalize=normalize)
689+
688690
X = dtype(X)
689691
y = dtype(y)
690692
ignore_warnings(clf.fit)(X, y)
@@ -694,8 +696,19 @@ def test_enet_float_precision():
694696

695697
assert_equal(clf.coef_.dtype, dtype)
696698

699+
# test precompute Gram array
700+
Gram = X.T.dot(X)
701+
clf_precompute = ElasticNet(alpha=0.5, max_iter=100,
702+
precompute=Gram,
703+
fit_intercept=fit_intercept,
704+
normalize=normalize)
705+
ignore_warnings(clf_precompute.fit)(X, y)
706+
assert_array_almost_equal(clf.coef_, clf_precompute.coef_)
707+
assert_array_almost_equal(clf.intercept_,
708+
clf_precompute.intercept_)
709+
697710
assert_array_almost_equal(coef[np.float32], coef[np.float64],
698-
decimal=4)
711+
decimal=4)
699712
assert_array_almost_equal(intercept[np.float32],
700-
intercept[np.float64],
701-
decimal=4)
713+
intercept[np.float64],
714+
decimal=4)

0 commit comments

Comments (0)