52
52
//!
53
53
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
54
54
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
55
+ use std:: cmp:: Ordering ;
55
56
use std:: fmt:: Debug ;
56
57
use std:: marker:: PhantomData ;
57
58
@@ -232,51 +233,53 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
232
233
yi[ i] = classes. iter ( ) . position ( |c| yc == * c) . unwrap ( ) ;
233
234
}
234
235
235
- if k < 2 {
236
- Err ( Failed :: fit ( & format ! (
236
+ match k . cmp ( & 2 ) {
237
+ Ordering :: Less => Err ( Failed :: fit ( & format ! (
237
238
"incorrect number of classes: {}. Should be >= 2." ,
238
239
k
239
- ) ) )
240
- } else if k == 2 {
241
- let x0 = M :: zeros ( 1 , num_attributes + 1 ) ;
242
-
243
- let objective = BinaryObjectiveFunction {
244
- x : x,
245
- y : yi,
246
- phantom : PhantomData ,
247
- } ;
248
-
249
- let result = LogisticRegression :: minimize ( x0, objective) ;
250
- let weights = result. x ;
251
-
252
- Ok ( LogisticRegression {
253
- coefficients : weights. slice ( 0 ..1 , 0 ..num_attributes) ,
254
- intercept : weights. slice ( 0 ..1 , num_attributes..num_attributes + 1 ) ,
255
- classes : classes,
256
- num_attributes : num_attributes,
257
- num_classes : k,
258
- } )
259
- } else {
260
- let x0 = M :: zeros ( 1 , ( num_attributes + 1 ) * k) ;
261
-
262
- let objective = MultiClassObjectiveFunction {
263
- x : x,
264
- y : yi,
265
- k : k,
266
- phantom : PhantomData ,
267
- } ;
268
-
269
- let result = LogisticRegression :: minimize ( x0, objective) ;
270
-
271
- let weights = result. x . reshape ( k, num_attributes + 1 ) ;
272
-
273
- Ok ( LogisticRegression {
274
- coefficients : weights. slice ( 0 ..k, 0 ..num_attributes) ,
275
- intercept : weights. slice ( 0 ..k, num_attributes..num_attributes + 1 ) ,
276
- classes : classes,
277
- num_attributes : num_attributes,
278
- num_classes : k,
279
- } )
240
+ ) ) ) ,
241
+ Ordering :: Equal => {
242
+ let x0 = M :: zeros ( 1 , num_attributes + 1 ) ;
243
+
244
+ let objective = BinaryObjectiveFunction {
245
+ x : x,
246
+ y : yi,
247
+ phantom : PhantomData ,
248
+ } ;
249
+
250
+ let result = LogisticRegression :: minimize ( x0, objective) ;
251
+ let weights = result. x ;
252
+
253
+ Ok ( LogisticRegression {
254
+ coefficients : weights. slice ( 0 ..1 , 0 ..num_attributes) ,
255
+ intercept : weights. slice ( 0 ..1 , num_attributes..num_attributes + 1 ) ,
256
+ classes : classes,
257
+ num_attributes : num_attributes,
258
+ num_classes : k,
259
+ } )
260
+ }
261
+ Ordering :: Greater => {
262
+ let x0 = M :: zeros ( 1 , ( num_attributes + 1 ) * k) ;
263
+
264
+ let objective = MultiClassObjectiveFunction {
265
+ x : x,
266
+ y : yi,
267
+ k : k,
268
+ phantom : PhantomData ,
269
+ } ;
270
+
271
+ let result = LogisticRegression :: minimize ( x0, objective) ;
272
+
273
+ let weights = result. x . reshape ( k, num_attributes + 1 ) ;
274
+
275
+ Ok ( LogisticRegression {
276
+ coefficients : weights. slice ( 0 ..k, 0 ..num_attributes) ,
277
+ intercept : weights. slice ( 0 ..k, num_attributes..num_attributes + 1 ) ,
278
+ classes : classes,
279
+ num_attributes : num_attributes,
280
+ num_classes : k,
281
+ } )
282
+ }
280
283
}
281
284
}
282
285
@@ -286,7 +289,6 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
286
289
let n = x. shape ( ) . 0 ;
287
290
let mut result = M :: zeros ( 1 , n) ;
288
291
if self . num_classes == 2 {
289
- let ( nrows, _) = x. shape ( ) ;
290
292
let y_hat: Vec < T > = x. matmul ( & self . coefficients . transpose ( ) ) . get_col_as_vec ( 0 ) ;
291
293
let intercept = self . intercept . get ( 0 , 0 ) ;
292
294
for i in 0 ..n {
0 commit comments