54
54
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
55
55
use std:: cmp:: Ordering ;
56
56
use std:: fmt:: Debug ;
57
- use std:: marker:: PhantomData ;
58
57
59
58
#[ cfg( feature = "serde" ) ]
60
59
use serde:: { Deserialize , Serialize } ;
@@ -79,9 +78,11 @@ pub enum LogisticRegressionSolverName {
79
78
/// Logistic Regression parameters
80
79
#[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
81
80
#[ derive( Debug , Clone ) ]
82
- pub struct LogisticRegressionParameters {
81
+ pub struct LogisticRegressionParameters < T : RealNumber > {
83
82
/// Solver to use for estimation of regression coefficients.
84
83
pub solver : LogisticRegressionSolverName ,
84
+ /// Regularization parameter.
85
+ pub alpha : T ,
85
86
}
86
87
87
88
/// Logistic Regression
@@ -113,21 +114,27 @@ trait ObjectiveFunction<T: RealNumber, M: Matrix<T>> {
113
114
struct BinaryObjectiveFunction < ' a , T : RealNumber , M : Matrix < T > > {
114
115
x : & ' a M ,
115
116
y : Vec < usize > ,
116
- phantom : PhantomData < & ' a T > ,
117
+ alpha : T ,
117
118
}
118
119
119
- impl LogisticRegressionParameters {
120
+ impl < T : RealNumber > LogisticRegressionParameters < T > {
120
121
/// Solver to use for estimation of regression coefficients.
121
122
pub fn with_solver ( mut self , solver : LogisticRegressionSolverName ) -> Self {
122
123
self . solver = solver;
123
124
self
124
125
}
126
+ /// Regularization parameter.
127
+ pub fn with_alpha ( mut self , alpha : T ) -> Self {
128
+ self . alpha = alpha;
129
+ self
130
+ }
125
131
}
126
132
127
- impl Default for LogisticRegressionParameters {
133
+ impl < T : RealNumber > Default for LogisticRegressionParameters < T > {
128
134
fn default ( ) -> Self {
129
135
LogisticRegressionParameters {
130
136
solver : LogisticRegressionSolverName :: LBFGS ,
137
+ alpha : T :: zero ( ) ,
131
138
}
132
139
}
133
140
}
@@ -156,13 +163,22 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
156
163
{
157
164
fn f ( & self , w_bias : & M ) -> T {
158
165
let mut f = T :: zero ( ) ;
159
- let ( n, _ ) = self . x . shape ( ) ;
166
+ let ( n, p ) = self . x . shape ( ) ;
160
167
161
168
for i in 0 ..n {
162
169
let wx = BinaryObjectiveFunction :: partial_dot ( w_bias, self . x , 0 , i) ;
163
170
f += wx. ln_1pe ( ) - ( T :: from ( self . y [ i] ) . unwrap ( ) ) * wx;
164
171
}
165
172
173
+ if self . alpha > T :: zero ( ) {
174
+ let mut w_squared = T :: zero ( ) ;
175
+ for i in 0 ..p {
176
+ let w = w_bias. get ( 0 , i) ;
177
+ w_squared += w * w;
178
+ }
179
+ f += T :: half ( ) * self . alpha * w_squared;
180
+ }
181
+
166
182
f
167
183
}
168
184
@@ -180,14 +196,21 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
180
196
}
181
197
g. set ( 0 , p, g. get ( 0 , p) - dyi) ;
182
198
}
199
+
200
+ if self . alpha > T :: zero ( ) {
201
+ for i in 0 ..p {
202
+ let w = w_bias. get ( 0 , i) ;
203
+ g. set ( 0 , i, g. get ( 0 , i) + self . alpha * w) ;
204
+ }
205
+ }
183
206
}
184
207
}
185
208
186
209
struct MultiClassObjectiveFunction < ' a , T : RealNumber , M : Matrix < T > > {
187
210
x : & ' a M ,
188
211
y : Vec < usize > ,
189
212
k : usize ,
190
- phantom : PhantomData < & ' a T > ,
213
+ alpha : T ,
191
214
}
192
215
193
216
impl < ' a , T : RealNumber , M : Matrix < T > > ObjectiveFunction < T , M >
@@ -209,6 +232,17 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
209
232
f -= prob. get ( 0 , self . y [ i] ) . ln ( ) ;
210
233
}
211
234
235
+ if self . alpha > T :: zero ( ) {
236
+ let mut w_squared = T :: zero ( ) ;
237
+ for i in 0 ..self . k {
238
+ for j in 0 ..p {
239
+ let wi = w_bias. get ( 0 , i * ( p + 1 ) + j) ;
240
+ w_squared += wi * wi;
241
+ }
242
+ }
243
+ f += T :: half ( ) * self . alpha * w_squared;
244
+ }
245
+
212
246
f
213
247
}
214
248
@@ -239,16 +273,27 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
239
273
g. set ( 0 , j * ( p + 1 ) + p, g. get ( 0 , j * ( p + 1 ) + p) - yi) ;
240
274
}
241
275
}
276
+
277
+ if self . alpha > T :: zero ( ) {
278
+ for i in 0 ..self . k {
279
+ for j in 0 ..p {
280
+ let pos = i * ( p + 1 ) ;
281
+ let wi = w. get ( 0 , pos + j) ;
282
+ g. set ( 0 , pos + j, g. get ( 0 , pos + j) + self . alpha * wi) ;
283
+ }
284
+ }
285
+ }
242
286
}
243
287
}
244
288
245
- impl < T : RealNumber , M : Matrix < T > > SupervisedEstimator < M , M :: RowVector , LogisticRegressionParameters >
289
+ impl < T : RealNumber , M : Matrix < T > >
290
+ SupervisedEstimator < M , M :: RowVector , LogisticRegressionParameters < T > >
246
291
for LogisticRegression < T , M >
247
292
{
248
293
fn fit (
249
294
x : & M ,
250
295
y : & M :: RowVector ,
251
- parameters : LogisticRegressionParameters ,
296
+ parameters : LogisticRegressionParameters < T > ,
252
297
) -> Result < Self , Failed > {
253
298
LogisticRegression :: fit ( x, y, parameters)
254
299
}
@@ -268,7 +313,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
268
313
pub fn fit (
269
314
x : & M ,
270
315
y : & M :: RowVector ,
271
- _parameters : LogisticRegressionParameters ,
316
+ parameters : LogisticRegressionParameters < T > ,
272
317
) -> Result < LogisticRegression < T , M > , Failed > {
273
318
let y_m = M :: from_row_vector ( y. clone ( ) ) ;
274
319
let ( x_nrows, num_attributes) = x. shape ( ) ;
@@ -302,7 +347,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
302
347
let objective = BinaryObjectiveFunction {
303
348
x,
304
349
y : yi,
305
- phantom : PhantomData ,
350
+ alpha : parameters . alpha ,
306
351
} ;
307
352
308
353
let result = LogisticRegression :: minimize ( x0, objective) ;
@@ -324,7 +369,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
324
369
x,
325
370
y : yi,
326
371
k,
327
- phantom : PhantomData ,
372
+ alpha : parameters . alpha ,
328
373
} ;
329
374
330
375
let result = LogisticRegression :: minimize ( x0, objective) ;
@@ -431,9 +476,9 @@ mod tests {
431
476
432
477
let objective = MultiClassObjectiveFunction {
433
478
x : & x,
434
- y,
479
+ y : y . clone ( ) ,
435
480
k : 3 ,
436
- phantom : PhantomData ,
481
+ alpha : 0.0 ,
437
482
} ;
438
483
439
484
let mut g: DenseMatrix < f64 > = DenseMatrix :: zeros ( 1 , 9 ) ;
@@ -454,6 +499,24 @@ mod tests {
454
499
] ) ) ;
455
500
456
501
assert ! ( ( f - 408.0052230582765 ) . abs( ) < std:: f64 :: EPSILON ) ;
502
+
503
+ let objective_reg = MultiClassObjectiveFunction {
504
+ x : & x,
505
+ y : y. clone ( ) ,
506
+ k : 3 ,
507
+ alpha : 1.0 ,
508
+ } ;
509
+
510
+ let f = objective_reg. f ( & DenseMatrix :: row_vector_from_array ( & [
511
+ 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. ,
512
+ ] ) ) ;
513
+ assert ! ( ( f - 487.5052 ) . abs( ) < 1e-4 ) ;
514
+
515
+ objective_reg. df (
516
+ & mut g,
517
+ & DenseMatrix :: row_vector_from_array ( & [ 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. ] ) ,
518
+ ) ;
519
+ assert ! ( ( g. get( 0 , 0 ) . abs( ) - 32.0 ) . abs( ) < 1e-4 ) ;
457
520
}
458
521
459
522
#[ test]
@@ -480,8 +543,8 @@ mod tests {
480
543
481
544
let objective = BinaryObjectiveFunction {
482
545
x : & x,
483
- y,
484
- phantom : PhantomData ,
546
+ y : y . clone ( ) ,
547
+ alpha : 0.0 ,
485
548
} ;
486
549
487
550
let mut g: DenseMatrix < f64 > = DenseMatrix :: zeros ( 1 , 3 ) ;
@@ -496,6 +559,20 @@ mod tests {
496
559
let f = objective. f ( & DenseMatrix :: row_vector_from_array ( & [ 1. , 2. , 3. ] ) ) ;
497
560
498
561
assert ! ( ( f - 59.76994756647412 ) . abs( ) < std:: f64 :: EPSILON ) ;
562
+
563
+ let objective_reg = BinaryObjectiveFunction {
564
+ x : & x,
565
+ y : y. clone ( ) ,
566
+ alpha : 1.0 ,
567
+ } ;
568
+
569
+ let f = objective_reg. f ( & DenseMatrix :: row_vector_from_array ( & [ 1. , 2. , 3. ] ) ) ;
570
+ assert ! ( ( f - 62.2699 ) . abs( ) < 1e-4 ) ;
571
+
572
+ objective_reg. df ( & mut g, & DenseMatrix :: row_vector_from_array ( & [ 1. , 2. , 3. ] ) ) ;
573
+ assert ! ( ( g. get( 0 , 0 ) - 27.0511 ) . abs( ) < 1e-4 ) ;
574
+ assert ! ( ( g. get( 0 , 1 ) - 12.239 ) . abs( ) < 1e-4 ) ;
575
+ assert ! ( ( g. get( 0 , 2 ) - 3.8693 ) . abs( ) < 1e-4 ) ;
499
576
}
500
577
501
578
#[ test]
@@ -547,6 +624,15 @@ mod tests {
547
624
let y_hat = lr. predict ( & x) . unwrap ( ) ;
548
625
549
626
assert ! ( accuracy( & y_hat, & y) > 0.9 ) ;
627
+
628
+ let lr_reg = LogisticRegression :: fit (
629
+ & x,
630
+ & y,
631
+ LogisticRegressionParameters :: default ( ) . with_alpha ( 10.0 ) ,
632
+ )
633
+ . unwrap ( ) ;
634
+
635
+ assert ! ( lr_reg. coefficients( ) . abs( ) . sum( ) < lr. coefficients( ) . abs( ) . sum( ) ) ;
550
636
}
551
637
552
638
#[ test]
@@ -561,6 +647,15 @@ mod tests {
561
647
let y_hat = lr. predict ( & x) . unwrap ( ) ;
562
648
563
649
assert ! ( accuracy( & y_hat, & y) > 0.9 ) ;
650
+
651
+ let lr_reg = LogisticRegression :: fit (
652
+ & x,
653
+ & y,
654
+ LogisticRegressionParameters :: default ( ) . with_alpha ( 10.0 ) ,
655
+ )
656
+ . unwrap ( ) ;
657
+
658
+ assert ! ( lr_reg. coefficients( ) . abs( ) . sum( ) < lr. coefficients( ) . abs( ) . sum( ) ) ;
564
659
}
565
660
566
661
#[ test]
@@ -622,6 +717,12 @@ mod tests {
622
717
] ;
623
718
624
719
let lr = LogisticRegression :: fit ( & x, & y, Default :: default ( ) ) . unwrap ( ) ;
720
+ let lr_reg = LogisticRegression :: fit (
721
+ & x,
722
+ & y,
723
+ LogisticRegressionParameters :: default ( ) . with_alpha ( 1.0 ) ,
724
+ )
725
+ . unwrap ( ) ;
625
726
626
727
let y_hat = lr. predict ( & x) . unwrap ( ) ;
627
728
@@ -632,5 +733,6 @@ mod tests {
632
733
. sum ( ) ;
633
734
634
735
assert ! ( error <= 1.0 ) ;
736
+ assert ! ( lr_reg. coefficients( ) . abs( ) . sum( ) < lr. coefficients( ) . abs( ) . sum( ) ) ;
635
737
}
636
738
}
0 commit comments