@@ -263,33 +263,41 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
263
263
/// Predicts estimated class labels from `x`
264
264
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
265
265
pub fn predict ( & self , x : & M ) -> Result < M :: RowVector , Failed > {
266
- let ( n, _) = x. shape ( ) ;
267
-
268
- let mut y_hat = M :: RowVector :: zeros ( n) ;
266
+ let mut y_hat = self . decision_function ( x) ?;
269
267
270
- for i in 0 ..n {
271
- let cls_idx = match self . predict_for_row ( x . get_row ( i ) ) == T :: one ( ) {
268
+ for i in 0 ..y_hat . len ( ) {
269
+ let cls_idx = match y_hat . get ( i ) > T :: zero ( ) {
272
270
false => self . classes [ 0 ] ,
273
271
true => self . classes [ 1 ] ,
274
272
} ;
273
+
275
274
y_hat. set ( i, cls_idx) ;
276
275
}
277
276
278
277
Ok ( y_hat)
279
278
}
280
279
280
+ /// Evaluates the decision function for the rows in `x`
281
+ /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
282
+ pub fn decision_function ( & self , x : & M ) -> Result < M :: RowVector , Failed > {
283
+ let ( n, _) = x. shape ( ) ;
284
+ let mut y_hat = M :: RowVector :: zeros ( n) ;
285
+
286
+ for i in 0 ..n {
287
+ y_hat. set ( i, self . predict_for_row ( x. get_row ( i) ) ) ;
288
+ }
289
+
290
+ Ok ( y_hat)
291
+ }
292
+
281
293
fn predict_for_row ( & self , x : M :: RowVector ) -> T {
282
294
let mut f = self . b ;
283
295
284
296
for i in 0 ..self . instances . len ( ) {
285
297
f += self . w [ i] * self . kernel . apply ( & x, & self . instances [ i] ) ;
286
298
}
287
299
288
- if f > T :: zero ( ) {
289
- T :: one ( )
290
- } else {
291
- -T :: one ( )
292
- }
300
+ f
293
301
}
294
302
}
295
303
@@ -772,6 +780,45 @@ mod tests {
772
780
assert ! ( accuracy( & y_hat, & y) >= 0.9 ) ;
773
781
}
774
782
783
+ #[ cfg_attr( target_arch = "wasm32" , wasm_bindgen_test:: wasm_bindgen_test) ]
784
+ #[ test]
785
+ fn svc_fit_decision_function ( ) {
786
+ let x = DenseMatrix :: from_2d_array ( & [ & [ 4.0 , 0.0 ] , & [ 0.0 , 4.0 ] , & [ 8.0 , 0.0 ] , & [ 0.0 , 8.0 ] ] ) ;
787
+
788
+ let x2 = DenseMatrix :: from_2d_array ( & [
789
+ & [ 3.0 , 3.0 ] ,
790
+ & [ 4.0 , 4.0 ] ,
791
+ & [ 6.0 , 6.0 ] ,
792
+ & [ 10.0 , 10.0 ] ,
793
+ & [ 1.0 , 1.0 ] ,
794
+ & [ 0.0 , 0.0 ] ,
795
+ ] ) ;
796
+
797
+ let y: Vec < f64 > = vec ! [ 0. , 0. , 1. , 1. ] ;
798
+
799
+ let y_hat = SVC :: fit (
800
+ & x,
801
+ & y,
802
+ SVCParameters :: default ( )
803
+ . with_c ( 200.0 )
804
+ . with_kernel ( Kernels :: linear ( ) ) ,
805
+ )
806
+ . and_then ( |lr| lr. decision_function ( & x2) )
807
+ . unwrap ( ) ;
808
+
809
+ // x can be classified by a straight line through [6.0, 0.0] and [0.0, 6.0],
810
+ // so the score should increase as points get further away from that line
811
+ println ! ( "{:?}" , y_hat) ;
812
+ assert ! ( y_hat[ 1 ] < y_hat[ 2 ] ) ;
813
+ assert ! ( y_hat[ 2 ] < y_hat[ 3 ] ) ;
814
+
815
+ // for negative scores the score should decrease
816
+ assert ! ( y_hat[ 4 ] > y_hat[ 5 ] ) ;
817
+
818
+ // y_hat[0] is on the line, so its score should be close to 0
819
+ assert ! ( y_hat[ 0 ] . abs( ) <= 0.1 ) ;
820
+ }
821
+
775
822
#[ cfg_attr( target_arch = "wasm32" , wasm_bindgen_test:: wasm_bindgen_test) ]
776
823
#[ test]
777
824
fn svc_fit_predict_rbf ( ) {
0 commit comments