@@ -68,7 +68,8 @@ use crate::optimization::FunctionOrder;
68
68
/// Logistic Regression
69
69
#[ derive( Serialize , Deserialize , Debug ) ]
70
70
pub struct LogisticRegression < T : RealNumber , M : Matrix < T > > {
71
- weights : M ,
71
+ coefficients : M ,
72
+ intercept : M ,
72
73
classes : Vec < T > ,
73
74
num_attributes : usize ,
74
75
num_classes : usize ,
@@ -109,7 +110,7 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
109
110
}
110
111
}
111
112
112
- return self . weights == other. weights ;
113
+ return self . coefficients == other. coefficients && self . intercept == other . intercept ;
113
114
}
114
115
}
115
116
}
@@ -246,9 +247,11 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
246
247
} ;
247
248
248
249
let result = LogisticRegression :: minimize ( x0, objective) ;
250
+ let weights = result. x ;
249
251
250
252
Ok ( LogisticRegression {
251
- weights : result. x ,
253
+ coefficients : weights. slice ( 0 ..1 , 0 ..num_attributes) ,
254
+ intercept : weights. slice ( 0 ..1 , num_attributes..num_attributes + 1 ) ,
252
255
classes : classes,
253
256
num_attributes : num_attributes,
254
257
num_classes : k,
@@ -268,7 +271,8 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
268
271
let weights = result. x . reshape ( k, num_attributes + 1 ) ;
269
272
270
273
Ok ( LogisticRegression {
271
- weights : weights,
274
+ coefficients : weights. slice ( 0 ..k, 0 ..num_attributes) ,
275
+ intercept : weights. slice ( 0 ..k, num_attributes..num_attributes + 1 ) ,
272
276
classes : classes,
273
277
num_attributes : num_attributes,
274
278
num_classes : k,
@@ -283,21 +287,26 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
283
287
let mut result = M :: zeros ( 1 , n) ;
284
288
if self . num_classes == 2 {
285
289
let ( nrows, _) = x. shape ( ) ;
286
- let x_and_bias = x. h_stack ( & M :: ones ( nrows, 1 ) ) ;
287
- let y_hat: Vec < T > = x_and_bias
288
- . matmul ( & self . weights . transpose ( ) )
289
- . get_col_as_vec ( 0 ) ;
290
+ let y_hat: Vec < T > = x. matmul ( & self . coefficients . transpose ( ) ) . get_col_as_vec ( 0 ) ;
291
+ let intercept = self . intercept . get ( 0 , 0 ) ;
290
292
for i in 0 ..n {
291
293
result. set (
292
294
0 ,
293
295
i,
294
- self . classes [ if y_hat[ i] . sigmoid ( ) > T :: half ( ) { 1 } else { 0 } ] ,
296
+ self . classes [ if ( y_hat[ i] + intercept) . sigmoid ( ) > T :: half ( ) {
297
+ 1
298
+ } else {
299
+ 0
300
+ } ] ,
295
301
) ;
296
302
}
297
303
} else {
298
- let ( nrows, _) = x. shape ( ) ;
299
- let x_and_bias = x. h_stack ( & M :: ones ( nrows, 1 ) ) ;
300
- let y_hat = x_and_bias. matmul ( & self . weights . transpose ( ) ) ;
304
+ let mut y_hat = x. matmul ( & self . coefficients . transpose ( ) ) ;
305
+ for r in 0 ..n {
306
+ for c in 0 ..self . num_classes {
307
+ y_hat. set ( r, c, y_hat. get ( r, c) + self . intercept . get ( c, 0 ) ) ;
308
+ }
309
+ }
301
310
let class_idxs = y_hat. argmax ( ) ;
302
311
for i in 0 ..n {
303
312
result. set ( 0 , i, self . classes [ class_idxs[ i] ] ) ;
@@ -307,17 +316,13 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
307
316
}
308
317
309
318
/// Get estimates regression coefficients
310
- pub fn coefficients ( & self ) -> M {
311
- self . weights
312
- . slice ( 0 ..self . num_classes , 0 ..self . num_attributes )
319
+ pub fn coefficients ( & self ) -> & M {
320
+ & self . coefficients
313
321
}
314
322
315
323
/// Get estimate of intercept
316
- pub fn intercept ( & self ) -> M {
317
- self . weights . slice (
318
- 0 ..self . num_classes ,
319
- self . num_attributes ..self . num_attributes + 1 ,
320
- )
324
+ pub fn intercept ( & self ) -> & M {
325
+ & self . intercept
321
326
}
322
327
323
328
fn minimize ( x0 : M , objective : impl ObjectiveFunction < T , M > ) -> OptimizerResult < T , M > {
@@ -336,7 +341,9 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
336
341
#[ cfg( test) ]
337
342
mod tests {
338
343
use super :: * ;
344
+ use crate :: dataset:: generator:: make_blobs;
339
345
use crate :: linalg:: naive:: dense_matrix:: * ;
346
+ use crate :: metrics:: accuracy;
340
347
341
348
#[ test]
342
349
fn multiclass_objective_f ( ) {
@@ -466,6 +473,34 @@ mod tests {
466
473
) ;
467
474
}
468
475
476
+ #[ test]
477
+ fn lr_fit_predict_multiclass ( ) {
478
+ let blobs = make_blobs ( 15 , 4 , 3 ) ;
479
+
480
+ let x = DenseMatrix :: from_vec ( 15 , 4 , & blobs. data ) ;
481
+ let y = blobs. target ;
482
+
483
+ let lr = LogisticRegression :: fit ( & x, & y) . unwrap ( ) ;
484
+
485
+ let y_hat = lr. predict ( & x) . unwrap ( ) ;
486
+
487
+ assert ! ( accuracy( & y_hat, & y) > 0.9 ) ;
488
+ }
489
+
490
+ #[ test]
491
+ fn lr_fit_predict_binary ( ) {
492
+ let blobs = make_blobs ( 20 , 4 , 2 ) ;
493
+
494
+ let x = DenseMatrix :: from_vec ( 20 , 4 , & blobs. data ) ;
495
+ let y = blobs. target ;
496
+
497
+ let lr = LogisticRegression :: fit ( & x, & y) . unwrap ( ) ;
498
+
499
+ let y_hat = lr. predict ( & x) . unwrap ( ) ;
500
+
501
+ assert ! ( accuracy( & y_hat, & y) > 0.9 ) ;
502
+ }
503
+
469
504
#[ test]
470
505
fn serde ( ) {
471
506
let x = DenseMatrix :: from_2d_array ( & [
0 commit comments