 //!
 //! let y_hat = svr.predict(&x).unwrap();
 //! ```
+//!
+//! ## References:
+//!
+//! * ["Support Vector Machines", Kowalczyk A., 2017](https://www.svm-tutorial.com/2017/10/support-vector-machines-succinctly-released/)
+//! * ["A Fast Algorithm for Training Support Vector Machines", Platt J.C., 1998](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/tr-98-14.pdf)
+//! * ["Working Set Selection Using Second Order Information for Training Support Vector Machines", Rong-En Fan et al., 2005](https://www.jmlr.org/papers/volume6/fan05a/fan05a.pdf)
+//! * ["A tutorial on support vector regression", Smola A.J., Schölkopf B., 2003](https://alex.smola.org/papers/2004/SmoSch04.pdf)
+
 use std::cell::{Ref, RefCell};
 use std::fmt::Debug;
 
@@ -87,6 +95,7 @@ struct SupportVector<T: RealNumber, V: BaseVector<T>> {
     k: T,
 }
 
+/// Sequential Minimal Optimization algorithm
 struct Optimizer<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
     tol: T,
     c: T,
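
For context on the struct documented above: SMO (Platt, 1998, cited in the new module-level references) decomposes the dual QP into a sequence of two-multiplier subproblems, each solvable in closed form. The sketch below illustrates the shape of one such analytic step under simplified assumptions; `pair_step` and its signature are hypothetical, not part of this crate:

```rust
/// Illustrative sketch of SMO's two-variable analytic step, not smartcore's
/// actual update. Take a Newton step along the pair's gradient gap, scaled by
/// the kernel curvature K_ii + K_jj - 2*K_ij, then clip so alpha_i stays in
/// the box [0, c]. (A full solver also clips against the second multiplier.)
fn pair_step(g_i: f64, g_j: f64, k_ii: f64, k_jj: f64, k_ij: f64, alpha_i: f64, c: f64) -> f64 {
    let curvature = (k_ii + k_jj - 2.0 * k_ij).max(1e-12); // guard degenerate kernels
    let step = (g_i - g_j) / curvature; // unconstrained optimum for the pair
    step.min(c - alpha_i).max(-alpha_i) // keep alpha_i + step within [0, c]
}
```
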
@@ -135,7 +144,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVR<T, M, K> {
             )));
         }
 
-        let optimizer = Optimizer::optimize(x, y, &kernel, &parameters);
+        let optimizer = Optimizer::new(x, y, &kernel, &parameters);
 
         let (support_vectors, weight, b) = optimizer.smo();
 
@@ -209,7 +218,7 @@ impl<T: RealNumber, V: BaseVector<T>> SupportVector<T, V> {
 }
 
 impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a, T, M, K> {
-    fn optimize(
+    fn new(
         x: &M,
         y: &M::RowVector,
         kernel: &'a K,
@@ -244,7 +253,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
         }
     }
 
-    fn minmax(&mut self) {
+    fn find_min_max_gradient(&mut self) {
         self.gmin = T::max_value();
         self.gmax = T::min_value();
 
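
`find_min_max_gradient` is a clearer name for what the body already does: the `T::max_value()` / `T::min_value()` initialization starts a linear scan for the least and most violating support vectors. A self-contained sketch of that scan over plain `f64` values (the crate's generic `T` and per-vector `grad` bookkeeping are simplified away):

```rust
/// Hedged sketch of the min/max gradient scan; `grads` stands in for the
/// per-support-vector gradients the real optimizer maintains.
fn find_min_max(grads: &[f64]) -> (f64, f64, usize, usize) {
    let (mut gmin, mut gmax) = (f64::MAX, f64::MIN);
    let (mut imin, mut imax) = (0, 0);
    for (i, &g) in grads.iter().enumerate() {
        if g < gmin { gmin = g; imin = i; }
        if g > gmax { gmax = g; imax = i; }
    }
    (gmin, gmax, imin, imax) // gmax - gmin <= tol signals convergence
}
```
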
@@ -278,10 +287,14 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
         }
     }
 
+    /// Solves the quadratic programming (QP) problem that arises during the training of a support vector machine (SVM).
+    /// Returns:
+    /// * support vectors
+    /// * hyperplane parameters: w and b
     fn smo(mut self) -> (Vec<M::RowVector>, Vec<T>, T) {
         let cache: Cache<T> = Cache::new(self.sv.len());
 
-        self.minmax();
+        self.find_min_max_gradient();
 
         while self.gmax - self.gmin > self.tol {
             let v1 = self.svmax;
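
The loop condition `self.gmax - self.gmin > self.tol` is the standard SMO stopping rule: iterate until the worst KKT violation falls below tolerance. `let v1 = self.svmax;` takes the maximally violating vector as the first half of the working pair; the referenced Fan et al. (2005) paper chooses the second half using second-order information. A hedged sketch of that selection rule, with illustrative names and plain `f64` in place of the crate's generics:

```rust
/// Second-order working-set selection after Fan et al. (2005): given the
/// first index i (gradient g_max), pick j maximizing the estimated objective
/// decrease d^2 / curvature. A sketch, not smartcore's exact code.
fn select_second(g_max: f64, grads: &[f64], k_ii: f64, k_diag: &[f64], k_i: &[f64]) -> Option<usize> {
    let mut best = None;
    let mut best_gain = f64::MIN;
    for j in 0..grads.len() {
        let d = g_max - grads[j];
        if d > 0.0 {
            let curvature = (k_ii + k_diag[j] - 2.0 * k_i[j]).max(1e-12);
            let gain = d * d / curvature; // larger estimated decrease is better
            if gain > best_gain {
                best_gain = gain;
                best = Some(j);
            }
        }
    }
    best
}
```
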
@@ -417,22 +430,22 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
                 v.grad[1] += si * k1[v.index] * delta_alpha_i + sj * k2[v.index] * delta_alpha_j;
             }
 
-            self.minmax();
+            self.find_min_max_gradient();
         }
 
         let b = -(self.gmax + self.gmin) / T::two();
 
-        let mut result: Vec<M::RowVector> = Vec::new();
-        let mut alpha: Vec<T> = Vec::new();
+        let mut support_vectors: Vec<M::RowVector> = Vec::new();
+        let mut w: Vec<T> = Vec::new();
 
         for v in self.sv {
             if v.alpha[0] != v.alpha[1] {
-                result.push(v.x);
-                alpha.push(v.alpha[1] - v.alpha[0]);
+                support_vectors.push(v.x);
+                w.push(v.alpha[1] - v.alpha[0]);
             }
         }
 
-        (result, alpha, b)
+        (support_vectors, w, b)
     }
 }
 
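
The renamed variables make `smo`'s contract explicit: each retained vector carries weight `w[i] = alpha*_i - alpha_i`, vectors whose two multipliers are equal contribute nothing and are dropped, and `b = -(gmax + gmin) / 2` centers the bias between the two residual extremes. Prediction needs nothing beyond this triple. A minimal sketch of the resulting decision function y(x) = sum_i w[i] * K(sv[i], x) + b, assuming an RBF kernel over `f64` slices rather than the crate's `Kernel`/`Matrix` traits:

```rust
/// RBF kernel K(a, b) = exp(-gamma * ||a - b||^2); a stand-in for the
/// crate's generic Kernel trait.
fn rbf(a: &[f64], b: &[f64], gamma: f64) -> f64 {
    let d2: f64 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();
    (-gamma * d2).exp()
}

/// Hedged sketch of SVR prediction from smo's output triple:
/// y(x) = sum_i w[i] * K(sv[i], x) + b.
fn predict_one(svs: &[Vec<f64>], w: &[f64], b: f64, x: &[f64], gamma: f64) -> f64 {
    svs.iter()
        .zip(w)
        .map(|(sv, wi)| wi * rbf(sv, x, gamma))
        .sum::<f64>()
        + b
}
```
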
@@ -497,8 +510,6 @@ mod tests {
             .and_then(|lr| lr.predict(&x))
             .unwrap();
 
-        println!("{:?}", y_hat);
-
         assert!(mean_squared_error(&y_hat, &y) < 2.5);
 
 