Skip to content

Commit ea39024

Browse files
ferrouillemorenol
authored andcommitted
Add SVC::decision_function (#135)
1 parent 4e94feb commit ea39024

File tree

1 file changed

+57
-10
lines changed

1 file changed

+57
-10
lines changed

src/svm/svc.rs

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -263,33 +263,41 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
263263
/// Predicts estimated class labels from `x`
264264
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
265265
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
266-
let (n, _) = x.shape();
267-
268-
let mut y_hat = M::RowVector::zeros(n);
266+
let mut y_hat = self.decision_function(x)?;
269267

270-
for i in 0..n {
271-
let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() {
268+
for i in 0..y_hat.len() {
269+
let cls_idx = match y_hat.get(i) > T::zero() {
272270
false => self.classes[0],
273271
true => self.classes[1],
274272
};
273+
275274
y_hat.set(i, cls_idx);
276275
}
277276

278277
Ok(y_hat)
279278
}
280279

280+
/// Evaluates the decision function for the rows in `x`
281+
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
282+
pub fn decision_function(&self, x: &M) -> Result<M::RowVector, Failed> {
283+
let (n, _) = x.shape();
284+
let mut y_hat = M::RowVector::zeros(n);
285+
286+
for i in 0..n {
287+
y_hat.set(i, self.predict_for_row(x.get_row(i)));
288+
}
289+
290+
Ok(y_hat)
291+
}
292+
281293
fn predict_for_row(&self, x: M::RowVector) -> T {
282294
let mut f = self.b;
283295

284296
for i in 0..self.instances.len() {
285297
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
286298
}
287299

288-
if f > T::zero() {
289-
T::one()
290-
} else {
291-
-T::one()
292-
}
300+
f
293301
}
294302
}
295303

@@ -772,6 +780,45 @@ mod tests {
772780
assert!(accuracy(&y_hat, &y) >= 0.9);
773781
}
774782

783+
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
784+
#[test]
785+
fn svc_fit_decision_function() {
786+
let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);
787+
788+
let x2 = DenseMatrix::from_2d_array(&[
789+
&[3.0, 3.0],
790+
&[4.0, 4.0],
791+
&[6.0, 6.0],
792+
&[10.0, 10.0],
793+
&[1.0, 1.0],
794+
&[0.0, 0.0],
795+
]);
796+
797+
let y: Vec<f64> = vec![0., 0., 1., 1.];
798+
799+
let y_hat = SVC::fit(
800+
&x,
801+
&y,
802+
SVCParameters::default()
803+
.with_c(200.0)
804+
.with_kernel(Kernels::linear()),
805+
)
806+
.and_then(|lr| lr.decision_function(&x2))
807+
.unwrap();
808+
809+
// x can be classified by a straight line through [6.0, 0.0] and [0.0, 6.0],
810+
// so the score should increase as points get further away from that line
811+
println!("{:?}", y_hat);
812+
assert!(y_hat[1] < y_hat[2]);
813+
assert!(y_hat[2] < y_hat[3]);
814+
815+
// for negative scores the score should decrease
816+
assert!(y_hat[4] > y_hat[5]);
817+
818+
// y_hat[0] is on the line, so its score should be close to 0
819+
assert!(y_hat[0].abs() <= 0.1);
820+
}
821+
775822
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
776823
#[test]
777824
fn svc_fit_predict_rbf() {

0 commit comments

Comments
 (0)