@@ -468,83 +468,81 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
468
468
idx_2 : Option < usize > ,
469
469
cache : & mut Cache < T , M , K > ,
470
470
) -> Option < ( usize , usize , T ) > {
471
- let mut idx_1 = idx_1;
472
- let mut idx_2 = idx_2;
473
-
474
- let mut k_v_12: Option < T > = None ;
475
-
476
- if idx_1. is_none ( ) && idx_2. is_none ( ) {
477
- self . find_min_max_gradient ( ) ;
478
- if self . gmax > -self . gmin {
479
- idx_2 = Some ( self . svmax ) ;
480
- } else {
481
- idx_1 = Some ( self . svmin ) ;
482
- }
483
- }
484
-
485
- if idx_2. is_none ( ) {
486
- let idx_1 = & self . sv [ idx_1. unwrap ( ) ] ;
487
- let km = idx_1. k ;
488
- let gm = idx_1. grad ;
489
- let mut best = T :: zero ( ) ;
490
- for i in 0 ..self . sv . len ( ) {
491
- let v = & self . sv [ i] ;
492
- let z = v. grad - gm;
493
- let k = cache. get ( idx_1, & v) ;
494
- let mut curv = km + v. k - T :: two ( ) * k;
495
- if curv <= T :: zero ( ) {
496
- curv = self . tau ;
497
- }
498
- let mu = z / curv;
499
- if ( mu > T :: zero ( ) && v. alpha < v. cmax ) || ( mu < T :: zero ( ) && v. alpha > v. cmin ) {
500
- let gain = z * mu;
501
- if gain > best {
502
- best = gain;
503
- idx_2 = Some ( i) ;
504
- k_v_12 = Some ( k) ;
471
+
472
+ match ( idx_1, idx_2) {
473
+ ( None , None ) => {
474
+ if self . gmax > -self . gmin {
475
+ self . select_pair ( None , Some ( self . svmax ) , cache)
476
+ } else {
477
+ self . select_pair ( Some ( self . svmin ) , None , cache)
478
+ }
479
+ } ,
480
+ ( Some ( idx_1) , None ) => {
481
+ let sv1 = & self . sv [ idx_1] ;
482
+ let mut idx_2 = None ;
483
+ let mut k_v_12 = None ;
484
+ let km = sv1. k ;
485
+ let gm = sv1. grad ;
486
+ let mut best = T :: zero ( ) ;
487
+ for i in 0 ..self . sv . len ( ) {
488
+ let v = & self . sv [ i] ;
489
+ let z = v. grad - gm;
490
+ let k = cache. get ( sv1, & v) ;
491
+ let mut curv = km + v. k - T :: two ( ) * k;
492
+ if curv <= T :: zero ( ) {
493
+ curv = self . tau ;
494
+ }
495
+ let mu = z / curv;
496
+ if ( mu > T :: zero ( ) && v. alpha < v. cmax ) || ( mu < T :: zero ( ) && v. alpha > v. cmin ) {
497
+ let gain = z * mu;
498
+ if gain > best {
499
+ best = gain;
500
+ idx_2 = Some ( i) ;
501
+ k_v_12 = Some ( k) ;
502
+ }
505
503
}
506
504
}
507
- }
508
- }
509
505
510
- if idx_1. is_none ( ) {
511
- let idx_2 = & self . sv [ idx_2. unwrap ( ) ] ;
512
- let km = idx_2. k ;
513
- let gm = idx_2. grad ;
514
- let mut best = T :: zero ( ) ;
515
- for i in 0 ..self . sv . len ( ) {
516
- let v = & self . sv [ i] ;
517
- let z = gm - v. grad ;
518
- let k = cache. get ( idx_2, v) ;
519
- let mut curv = km + v. k - T :: two ( ) * k;
520
- if curv <= T :: zero ( ) {
521
- curv = self . tau ;
522
- }
506
+ idx_2. map ( |idx_2| {
507
+ ( idx_1, idx_2, k_v_12. unwrap_or ( self . kernel . apply ( & self . sv [ idx_1] . x , & self . sv [ idx_2] . x ) ) )
508
+ } )
509
+ } ,
510
+ ( None , Some ( idx_2) ) => {
511
+ let mut idx_1 = None ;
512
+ let sv2 = & self . sv [ idx_2] ;
513
+ let mut k_v_12 = None ;
514
+ let km = sv2. k ;
515
+ let gm = sv2. grad ;
516
+ let mut best = T :: zero ( ) ;
517
+ for i in 0 ..self . sv . len ( ) {
518
+ let v = & self . sv [ i] ;
519
+ let z = gm - v. grad ;
520
+ let k = cache. get ( sv2, v) ;
521
+ let mut curv = km + v. k - T :: two ( ) * k;
522
+ if curv <= T :: zero ( ) {
523
+ curv = self . tau ;
524
+ }
523
525
524
- let mu = z / curv;
525
- if ( mu > T :: zero ( ) && v. alpha > v. cmin ) || ( mu < T :: zero ( ) && v. alpha < v. cmax ) {
526
- let gain = z * mu;
527
- if gain > best {
528
- best = gain;
529
- idx_1 = Some ( i) ;
530
- k_v_12 = Some ( k) ;
526
+ let mu = z / curv;
527
+ if ( mu > T :: zero ( ) && v. alpha > v. cmin ) || ( mu < T :: zero ( ) && v. alpha < v. cmax ) {
528
+ let gain = z * mu;
529
+ if gain > best {
530
+ best = gain;
531
+ idx_1 = Some ( i) ;
532
+ k_v_12 = Some ( k) ;
533
+ }
531
534
}
532
535
}
533
- }
534
- }
535
-
536
- if idx_1. is_none ( ) || idx_2. is_none ( ) {
537
- None
538
- } else {
539
- let idx_1 = idx_1. unwrap ( ) ;
540
- let idx_2 = idx_2. unwrap ( ) ;
541
536
542
- if k_v_12. is_none ( ) {
543
- k_v_12 = Some ( self . kernel . apply ( & self . sv [ idx_1] . x , & self . sv [ idx_2] . x ) ) ;
537
+ idx_1. map ( |idx_1| {
538
+ ( idx_1, idx_2, k_v_12. unwrap_or ( self . kernel . apply ( & self . sv [ idx_1] . x , & self . sv [ idx_2] . x ) ) )
539
+ } )
540
+ } ,
541
+ ( Some ( idx_1) , Some ( idx_2) ) => {
542
+ Some ( ( idx_1, idx_2, self . kernel . apply ( & self . sv [ idx_1] . x , & self . sv [ idx_2] . x ) ) )
544
543
}
545
-
546
- Some ( ( idx_1, idx_2, k_v_12. unwrap ( ) ) )
547
544
}
545
+
548
546
}
549
547
550
548
fn smo (
0 commit comments