1
1
use super :: metrics;
2
- use crate :: types;
3
2
4
3
use num:: Float ;
5
4
use ordered_float:: OrderedFloat ;
@@ -13,7 +12,14 @@ use std::fmt;
13
12
use std:: hash:: { Hash , Hasher } ;
14
13
use std:: rc:: Rc ;
15
14
use std:: sync:: { Arc , RwLock , Weak } ;
16
- use std:: thread;
15
+ // use std::thread;
16
+
17
+ struct SelectParams { // argument bundle for `select_neighbors`; replaces its former positional m / lc / bool / bool parameters
18
+ m : usize , // cap on the number of neighbors selected: the result heap is sized with it and selection loops stop at `r.len() < m`
19
+ lc : usize , // index into each node's `neighbors` list; callers pass their current `lc` (presumably the HNSW layer, TODO confirm)
20
+ extend_candidates : bool , // when true, the candidate set is extended by the candidates' own neighbors before selection
21
+ keep_pruned_connections : bool , // when true, discarded candidates are added back until `m` results are reached
22
+ }
17
23
18
24
#[ derive( Debug ) ]
19
25
pub enum HNSWError {
@@ -48,7 +54,7 @@ pub struct SearchResult<T: Float, R: Float> {
48
54
impl < T : Float , R : Float > SearchResult < T , R > {
49
55
fn new ( sim : OrderedFloat < R > , name : & str , data : & [ T ] ) -> Self {
50
56
SearchResult {
51
- sim : sim ,
57
+ sim,
52
58
name : name. to_owned ( ) ,
53
59
data : data. to_vec ( ) ,
54
60
}
@@ -108,7 +114,7 @@ where
108
114
self . neighbors
109
115
. iter( )
110
116
. map( |l| {
111
- l. into_iter ( )
117
+ l. iter ( )
112
118
. map( |n| n. upgrade( ) . read( ) . name. to_owned( ) )
113
119
. collect:: <Vec <String >>( )
114
120
} )
@@ -249,8 +255,8 @@ where
249
255
{
250
256
fn new ( sim : OrderedFloat < R > , node : Node < T > ) -> Self { // constructor: pairs a `sim` value with the node it belongs to
251
257
let sp = _SimPair { // inner payload struct holding the two fields
252
- sim : sim ,
253
- node : node ,
258
+ sim,
259
+ node,
254
260
} ;
255
261
SimPair ( Rc :: new ( RefCell :: new ( sp) ) ) // the public handle stores the payload behind Rc<RefCell<_>>
256
262
}
@@ -311,7 +317,7 @@ pub struct Index<T: Float, R: Float> {
311
317
pub layers : Vec < HashSet < NodeWeak < T > > > , // distinct nodes in each layer
312
318
pub nodes : HashMap < String , Node < T > > , // hashmap of nodes
313
319
pub enterpoint : Option < NodeWeak < T > > , // enterpoint node
314
- rng_ : StdRng , // rng for level generation
320
+ pub rng_ : StdRng , // rng for level generation
315
321
}
316
322
317
323
impl < T : Float , R : Float > Index < T , R > {
@@ -324,13 +330,13 @@ impl<T: Float, R: Float> Index<T, R> {
324
330
) -> Self {
325
331
Index {
326
332
name : name. to_string ( ) ,
327
- mfunc : mfunc ,
333
+ mfunc,
328
334
mfunc_kind : metrics:: MetricFuncs :: Euclidean ,
329
- data_dim : data_dim ,
330
- m : m ,
335
+ data_dim,
336
+ m,
331
337
m_max : m,
332
338
m_max_0 : m * 2 ,
333
- ef_construction : ef_construction ,
339
+ ef_construction,
334
340
level_mult : 1.0 / ( 1.0 * m as f64 ) . ln ( ) ,
335
341
node_count : 0 ,
336
342
max_layer : 0 ,
@@ -371,35 +377,6 @@ impl<T: Float, R: Float> fmt::Debug for Index<T, R> {
371
377
}
372
378
}
373
379
374
- impl From < & types:: IndexRedis > for Index < f32 , f32 > {
375
- fn from ( index : & types:: IndexRedis ) -> Self {
376
- Index {
377
- name : index. name . clone ( ) ,
378
- mfunc : match index. mfunc_kind . as_str ( ) {
379
- "Euclidean" => Box :: new ( metrics:: euclidean) ,
380
- _ => Box :: new ( metrics:: euclidean) ,
381
- } ,
382
- mfunc_kind : match index. mfunc_kind . as_str ( ) {
383
- "Euclidean" => metrics:: MetricFuncs :: Euclidean ,
384
- _ => metrics:: MetricFuncs :: Euclidean ,
385
- } ,
386
- data_dim : index. data_dim ,
387
- m : index. m ,
388
- m_max : index. m_max ,
389
- m_max_0 : index. m_max_0 ,
390
- ef_construction : index. ef_construction ,
391
- level_mult : index. level_mult ,
392
- node_count : index. node_count ,
393
- max_layer : index. max_layer ,
394
- // the next 3 need to be populated from redis
395
- layers : Vec :: new ( ) ,
396
- nodes : HashMap :: new ( ) ,
397
- enterpoint : None ,
398
- rng_ : StdRng :: from_entropy ( ) ,
399
- }
400
- }
401
- }
402
-
403
380
impl < T , R > Index < T , R >
404
381
where
405
382
T : Float + Send + Sync + ' static ,
@@ -409,7 +386,7 @@ where
409
386
& mut self ,
410
387
name : & str ,
411
388
data : & [ T ] ,
412
- update_fn : fn ( String , Node < T > ) ,
389
+ update_fn : impl Fn ( String , Node < T > ) ,
413
390
) -> Result < ( ) , HNSWError > {
414
391
if data. len ( ) != self . data_dim {
415
392
return Err ( format ! ( "data dimension: {} does not match Index" , data. len( ) ) . into ( ) ) ;
@@ -429,7 +406,7 @@ where
429
406
return Ok ( ( ) ) ;
430
407
}
431
408
432
- if ! self . nodes . get ( name) . is_none ( ) {
409
+ if self . nodes . get ( name) . is_some ( ) {
433
410
return Err ( format ! ( "Node: {:?} already exists" , name) . into ( ) ) ;
434
411
}
435
412
@@ -439,7 +416,7 @@ where
439
416
pub fn delete_node (
440
417
& mut self ,
441
418
name : & str ,
442
- update_fn : fn ( String , Node < T > ) ,
419
+ update_fn : impl Fn ( String , Node < T > ) ,
443
420
) -> Result < ( ) , HNSWError > {
444
421
let node = match self . nodes . remove ( name) {
445
422
Some ( node) => node,
@@ -467,7 +444,7 @@ where
467
444
for n in updated {
468
445
let name = n. read ( ) . name . clone ( ) ;
469
446
let node = n. clone ( ) ;
470
- let _ = thread :: spawn ( move || update_fn ( name, node) ) ;
447
+ update_fn ( name, node) ;
471
448
}
472
449
473
450
// update enterpoint if necessary
@@ -515,7 +492,7 @@ where
515
492
& mut self ,
516
493
name : & str ,
517
494
data : & [ T ] ,
518
- update_fn : fn ( String , Node < T > ) ,
495
+ update_fn : impl Fn ( String , Node < T > ) ,
519
496
) -> Result < ( ) , HNSWError > {
520
497
let l = self . gen_random_level ( ) ;
521
498
let l_max = self . max_layer ;
@@ -547,7 +524,13 @@ where
547
524
let mut updated = HashSet :: new ( ) ;
548
525
for lc in ( 0 ..( min ( l_max, l) + 1 ) ) . rev ( ) {
549
526
w = self . search_level ( data, & ep. upgrade ( ) , self . ef_construction , lc) ;
550
- let mut neighbors = self . select_neighbors ( query, & w, self . m , lc, true , true , None ) ;
527
+ let params = SelectParams {
528
+ m : self . m ,
529
+ lc,
530
+ extend_candidates : true ,
531
+ keep_pruned_connections : true
532
+ } ;
533
+ let mut neighbors = self . select_neighbors ( query, & w, params, None ) ;
551
534
self . connect_neighbors ( query, & neighbors, lc) ;
552
535
553
536
// add node to list of nodes to be updated in redis
@@ -578,8 +561,14 @@ where
578
561
579
562
let m_max = if lc == 0 { self . m_max_0 } else { self . m_max } ;
580
563
if econn. len ( ) > m_max {
564
+ let params = SelectParams {
565
+ m : m_max,
566
+ lc,
567
+ extend_candidates : true ,
568
+ keep_pruned_connections : true
569
+ } ;
581
570
let enewconn =
582
- self . select_neighbors ( & er. node , & econn, m_max , lc , true , true , None ) ;
571
+ self . select_neighbors ( & er. node , & econn, params , None ) ;
583
572
let up = self . update_node_connections ( & er. node , & enewconn, & econn, lc, None ) ;
584
573
for u in up {
585
574
updated. insert ( u) ;
@@ -594,7 +583,7 @@ where
594
583
for n in updated {
595
584
let name = n. read ( ) . name . clone ( ) ;
596
585
let node = n. clone ( ) ;
597
- let _ = thread :: spawn ( move || update_fn ( name, node) ) ;
586
+ update_fn ( name, node) ;
598
587
}
599
588
600
589
// new enterpoint if we're in a higher layer
@@ -692,18 +681,15 @@ where
692
681
& self ,
693
682
query : & Node < T > ,
694
683
c : & BinaryHeap < SimPair < T , R > > ,
695
- m : usize ,
696
- lc : usize ,
697
- extend_candidates : bool ,
698
- keep_pruned_connections : bool ,
684
+ params : SelectParams ,
699
685
ignored_node : Option < & Node < T > > ,
700
686
) -> BinaryHeap < SimPair < T , R > > {
701
- let mut r: BinaryHeap < SimPair < T , R > > = BinaryHeap :: with_capacity ( m) ;
687
+ let mut r: BinaryHeap < SimPair < T , R > > = BinaryHeap :: with_capacity ( params . m ) ;
702
688
let mut w = c. clone ( ) ;
703
689
let mut wd = BinaryHeap :: new ( ) ;
704
690
705
691
// extend candidates by their neighbors
706
- if extend_candidates {
692
+ if params . extend_candidates {
707
693
let mut ccopy = c. clone ( ) ;
708
694
709
695
let mut v = HashSet :: with_capacity ( ccopy. capacity ( ) ) ;
@@ -716,7 +702,7 @@ where
716
702
while !ccopy. is_empty ( ) {
717
703
let epair = ccopy. pop ( ) . unwrap ( ) ;
718
704
719
- for eneighbor in & epair. read ( ) . node . read ( ) . neighbors [ lc] {
705
+ for eneighbor in & epair. read ( ) . node . read ( ) . neighbors [ params . lc ] {
720
706
let eneighbor = eneighbor. upgrade ( ) ;
721
707
if eneighbor == * query
722
708
|| ( ignored_node. is_some ( ) && eneighbor == * ignored_node. unwrap ( ) )
@@ -738,7 +724,7 @@ where
738
724
}
739
725
}
740
726
741
- while !w. is_empty ( ) && r. len ( ) < m {
727
+ while !w. is_empty ( ) && r. len ( ) < params . m {
742
728
let epair = w. pop ( ) . unwrap ( ) ;
743
729
let enr = epair. read ( ) ;
744
730
@@ -755,8 +741,8 @@ where
755
741
}
756
742
757
743
// add back some of the discarded connections
758
- if keep_pruned_connections {
759
- while !wd. is_empty ( ) && r. len ( ) < m {
744
+ if params . keep_pruned_connections {
745
+ while !wd. is_empty ( ) && r. len ( ) < params . m {
760
746
let ppair = wd. pop ( ) . unwrap ( ) ;
761
747
{
762
748
let pr = ppair. read ( ) ;
@@ -813,11 +799,8 @@ where
813
799
updated. insert ( npr. node . clone ( ) ) ;
814
800
// if new neighbor exists in the old set then we remove it from
815
801
// the set of neighbors to be removed
816
- match rmconn. iter ( ) . position ( |n| n. read ( ) . node == npr. node ) {
817
- Some ( index) => {
818
- rmconn. remove ( index) ;
819
- }
820
- None => ( ) ,
802
+ if let Some ( index) = rmconn. iter ( ) . position ( |n| n. read ( ) . node == npr. node ) {
803
+ rmconn. remove ( index) ;
821
804
}
822
805
}
823
806
@@ -864,7 +847,13 @@ where
864
847
}
865
848
866
849
let m_max = if lc == 0 { self . m_max_0 } else { self . m_max } ;
867
- nnewconn = self . select_neighbors ( & n, & nconn, m_max, lc, true , true , Some ( node) ) ;
850
+ let params = SelectParams {
851
+ m : m_max,
852
+ lc,
853
+ extend_candidates : true ,
854
+ keep_pruned_connections : true
855
+ } ;
856
+ nnewconn = self . select_neighbors ( & n, & nconn, params, Some ( node) ) ;
868
857
}
869
858
updated. insert ( n. clone ( ) ) ;
870
859
let up = self . update_node_connections ( & n, & nnewconn, & nconn, lc, Some ( node) ) ;
@@ -896,7 +885,7 @@ where
896
885
let cnr = cr. node . read ( ) ;
897
886
res. push ( SearchResult :: new (
898
887
cr. sim ,
899
- & ( ( & cnr. name ) . split ( "." ) . collect :: < Vec < & str > > ( ) )
888
+ & ( ( & cnr. name ) . split ( '.' ) . collect :: < Vec < & str > > ( ) )
900
889
. last ( )
901
890
. unwrap ( ) ,
902
891
& cnr. data ,
0 commit comments