Skip to content

Commit bf8d0c0

Browse files
Volodymyr OrlovVolodymyr Orlov
authored andcommitted
fix: SVC: some more post-review refactoring
1 parent aa38fc8 commit bf8d0c0

File tree

1 file changed

+66
-68
lines changed

1 file changed

+66
-68
lines changed

src/svm/svc.rs

Lines changed: 66 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -468,83 +468,81 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
468468
idx_2: Option<usize>,
469469
cache: &mut Cache<T, M, K>,
470470
) -> Option<(usize, usize, T)> {
471-
let mut idx_1 = idx_1;
472-
let mut idx_2 = idx_2;
473-
474-
let mut k_v_12: Option<T> = None;
475-
476-
if idx_1.is_none() && idx_2.is_none() {
477-
self.find_min_max_gradient();
478-
if self.gmax > -self.gmin {
479-
idx_2 = Some(self.svmax);
480-
} else {
481-
idx_1 = Some(self.svmin);
482-
}
483-
}
484-
485-
if idx_2.is_none() {
486-
let idx_1 = &self.sv[idx_1.unwrap()];
487-
let km = idx_1.k;
488-
let gm = idx_1.grad;
489-
let mut best = T::zero();
490-
for i in 0..self.sv.len() {
491-
let v = &self.sv[i];
492-
let z = v.grad - gm;
493-
let k = cache.get(idx_1, &v);
494-
let mut curv = km + v.k - T::two() * k;
495-
if curv <= T::zero() {
496-
curv = self.tau;
497-
}
498-
let mu = z / curv;
499-
if (mu > T::zero() && v.alpha < v.cmax) || (mu < T::zero() && v.alpha > v.cmin) {
500-
let gain = z * mu;
501-
if gain > best {
502-
best = gain;
503-
idx_2 = Some(i);
504-
k_v_12 = Some(k);
471+
472+
match (idx_1, idx_2) {
473+
(None, None) => {
474+
if self.gmax > -self.gmin {
475+
self.select_pair(None, Some(self.svmax), cache)
476+
} else {
477+
self.select_pair(Some(self.svmin), None, cache)
478+
}
479+
},
480+
(Some(idx_1), None) => {
481+
let sv1 = &self.sv[idx_1];
482+
let mut idx_2 = None;
483+
let mut k_v_12 = None;
484+
let km = sv1.k;
485+
let gm = sv1.grad;
486+
let mut best = T::zero();
487+
for i in 0..self.sv.len() {
488+
let v = &self.sv[i];
489+
let z = v.grad - gm;
490+
let k = cache.get(sv1, &v);
491+
let mut curv = km + v.k - T::two() * k;
492+
if curv <= T::zero() {
493+
curv = self.tau;
494+
}
495+
let mu = z / curv;
496+
if (mu > T::zero() && v.alpha < v.cmax) || (mu < T::zero() && v.alpha > v.cmin) {
497+
let gain = z * mu;
498+
if gain > best {
499+
best = gain;
500+
idx_2 = Some(i);
501+
k_v_12 = Some(k);
502+
}
505503
}
506504
}
507-
}
508-
}
509505

510-
if idx_1.is_none() {
511-
let idx_2 = &self.sv[idx_2.unwrap()];
512-
let km = idx_2.k;
513-
let gm = idx_2.grad;
514-
let mut best = T::zero();
515-
for i in 0..self.sv.len() {
516-
let v = &self.sv[i];
517-
let z = gm - v.grad;
518-
let k = cache.get(idx_2, v);
519-
let mut curv = km + v.k - T::two() * k;
520-
if curv <= T::zero() {
521-
curv = self.tau;
522-
}
506+
idx_2.map(|idx_2| {
507+
(idx_1, idx_2, k_v_12.unwrap_or(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)))
508+
})
509+
},
510+
(None, Some(idx_2)) => {
511+
let mut idx_1 = None;
512+
let sv2 = &self.sv[idx_2];
513+
let mut k_v_12 = None;
514+
let km = sv2.k;
515+
let gm = sv2.grad;
516+
let mut best = T::zero();
517+
for i in 0..self.sv.len() {
518+
let v = &self.sv[i];
519+
let z = gm - v.grad;
520+
let k = cache.get(sv2, v);
521+
let mut curv = km + v.k - T::two() * k;
522+
if curv <= T::zero() {
523+
curv = self.tau;
524+
}
523525

524-
let mu = z / curv;
525-
if (mu > T::zero() && v.alpha > v.cmin) || (mu < T::zero() && v.alpha < v.cmax) {
526-
let gain = z * mu;
527-
if gain > best {
528-
best = gain;
529-
idx_1 = Some(i);
530-
k_v_12 = Some(k);
526+
let mu = z / curv;
527+
if (mu > T::zero() && v.alpha > v.cmin) || (mu < T::zero() && v.alpha < v.cmax) {
528+
let gain = z * mu;
529+
if gain > best {
530+
best = gain;
531+
idx_1 = Some(i);
532+
k_v_12 = Some(k);
533+
}
531534
}
532535
}
533-
}
534-
}
535-
536-
if idx_1.is_none() || idx_2.is_none() {
537-
None
538-
} else {
539-
let idx_1 = idx_1.unwrap();
540-
let idx_2 = idx_2.unwrap();
541536

542-
if k_v_12.is_none() {
543-
k_v_12 = Some(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x));
537+
idx_1.map(|idx_1| {
538+
(idx_1, idx_2, k_v_12.unwrap_or(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)))
539+
})
540+
},
541+
(Some(idx_1), Some(idx_2)) => {
542+
Some((idx_1, idx_2, self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)))
544543
}
545-
546-
Some((idx_1, idx_2, k_v_12.unwrap()))
547544
}
545+
548546
}
549547

550548
fn smo(

0 commit comments

Comments
 (0)