1
1
#![ feature( trait_alias) ]
2
2
#![ allow( clippy:: len_without_is_empty) ]
3
3
4
+ pub mod visited;
5
+
4
6
use base:: index:: * ;
5
7
use base:: operator:: * ;
6
8
use base:: scalar:: F32 ;
7
9
use base:: search:: * ;
8
10
use bytemuck:: { Pod , Zeroable } ;
9
11
use common:: dir_ops:: sync_dir;
10
12
use common:: mmap_array:: MmapArray ;
11
- use parking_lot:: { Mutex , RwLock , RwLockWriteGuard } ;
13
+ use parking_lot:: { RwLock , RwLockWriteGuard } ;
12
14
use quantization:: operator:: OperatorQuantization ;
13
15
use quantization:: Quantization ;
14
16
use rayon:: iter:: { IntoParallelIterator , ParallelIterator } ;
@@ -20,6 +22,7 @@ use std::path::Path;
20
22
use std:: sync:: Arc ;
21
23
use storage:: operator:: OperatorStorage ;
22
24
use storage:: StorageCollection ;
25
+ use visited:: { VisitedGuard , VisitedPool } ;
23
26
24
27
pub trait OperatorHnsw = Operator + OperatorQuantization + OperatorStorage ;
25
28
@@ -28,6 +31,11 @@ pub struct Hnsw<O: OperatorHnsw> {
28
31
}
29
32
30
33
impl < O : OperatorHnsw > Hnsw < O > {
34
/// Wraps an existing [`HnswMmap`] in an [`Hnsw`] handle without touching disk.
///
/// Only compiled when the `stand-alone-test` feature is enabled.
#[cfg(feature = "stand-alone-test")]
pub fn new(mmap: HnswMmap<O>) -> Self {
    Hnsw { mmap }
}
38
+
31
39
pub fn create < S : Source < O > > ( path : & Path , options : IndexOptions , source : & S ) -> Self {
32
40
create_dir ( path) . unwrap ( ) ;
33
41
let ram = make ( path, options, source) ;
@@ -86,22 +94,41 @@ pub struct HnswRam<O: OperatorHnsw> {
86
94
visited : VisitedPool ,
87
95
}
88
96
89
- struct HnswRamGraph {
90
- vertexs : Vec < HnswRamVertex > ,
97
+ impl < O : OperatorHnsw > HnswRam < O > {
98
+ #[ cfg( feature = "stand-alone-test" ) ]
99
+ pub fn new (
100
+ storage : Arc < StorageCollection < O > > ,
101
+ quantization : Quantization < O , StorageCollection < O > > ,
102
+ m : u32 ,
103
+ graph : HnswRamGraph ,
104
+ visited : VisitedPool ,
105
+ ) -> Self {
106
+ Self {
107
+ storage,
108
+ quantization,
109
+ m,
110
+ graph,
111
+ visited,
112
+ }
113
+ }
114
+ }
115
+
116
+ pub struct HnswRamGraph {
117
+ pub vertexs : Vec < HnswRamVertex > ,
91
118
}
92
119
93
- struct HnswRamVertex {
94
- layers : Vec < RwLock < HnswRamLayer > > ,
120
+ pub struct HnswRamVertex {
121
+ pub layers : Vec < RwLock < HnswRamLayer > > ,
95
122
}
96
123
97
124
impl HnswRamVertex {
98
- fn levels ( & self ) -> u8 {
125
+ pub fn levels ( & self ) -> u8 {
99
126
self . layers . len ( ) as u8 - 1
100
127
}
101
128
}
102
129
103
- struct HnswRamLayer {
104
- edges : Vec < ( F32 , u32 ) > ,
130
+ pub struct HnswRamLayer {
131
+ pub edges : Vec < ( F32 , u32 ) > ,
105
132
}
106
133
107
134
pub struct HnswMmap < O : OperatorHnsw > {
@@ -117,8 +144,31 @@ pub struct HnswMmap<O: OperatorHnsw> {
117
144
visited : VisitedPool ,
118
145
}
119
146
147
+ impl < O : OperatorHnsw > HnswMmap < O > {
148
+ #[ cfg( feature = "stand-alone-test" ) ]
149
+ pub fn new (
150
+ storage : Arc < StorageCollection < O > > ,
151
+ quantization : Quantization < O , StorageCollection < O > > ,
152
+ m : u32 ,
153
+ edges : MmapArray < HnswMmapEdge > ,
154
+ by_layer_id : MmapArray < usize > ,
155
+ by_vertex_id : MmapArray < usize > ,
156
+ visited : VisitedPool ,
157
+ ) -> Self {
158
+ Self {
159
+ storage,
160
+ quantization,
161
+ m,
162
+ edges,
163
+ by_layer_id,
164
+ by_vertex_id,
165
+ visited,
166
+ }
167
+ }
168
+ }
169
+
120
170
#[ derive( Debug , Clone , Copy , Default ) ]
121
- struct HnswMmapEdge ( #[ allow( dead_code) ] F32 , u32 ) ;
171
+ pub struct HnswMmapEdge ( #[ allow( dead_code) ] F32 , u32 ) ;
122
172
// We may convert a memory-mapped graph into an in-memory graph
123
173
// so that merging sealed segments is faster
124
174
@@ -574,7 +624,7 @@ pub fn local_search_vbase<'a, O: OperatorHnsw>(
574
624
} )
575
625
}
576
626
577
- fn count_layers_of_a_vertex ( m : u32 , i : u32 ) -> u8 {
627
+ pub fn count_layers_of_a_vertex ( m : u32 , i : u32 ) -> u8 {
578
628
let mut x = i + 1 ;
579
629
let mut ans = 1 ;
580
630
while x % m == 0 {
@@ -584,7 +634,7 @@ fn count_layers_of_a_vertex(m: u32, i: u32) -> u8 {
584
634
ans
585
635
}
586
636
587
- fn count_max_edges_of_a_layer ( m : u32 , j : u8 ) -> u32 {
637
+ pub fn count_max_edges_of_a_layer ( m : u32 , j : u8 ) -> u32 {
588
638
if j == 0 {
589
639
m * 2
590
640
} else {
@@ -610,123 +660,6 @@ fn find_edges<O: OperatorHnsw>(mmap: &HnswMmap<O>, u: u32, level: u8) -> &[HnswM
610
660
& mmap. edges [ index]
611
661
}
612
662
613
- struct VisitedPool {
614
- n : u32 ,
615
- locked_buffers : Mutex < Vec < VisitedBuffer > > ,
616
- }
617
-
618
- impl VisitedPool {
619
- pub fn new ( n : u32 ) -> Self {
620
- Self {
621
- n,
622
- locked_buffers : Mutex :: new ( Vec :: new ( ) ) ,
623
- }
624
- }
625
- pub fn fetch ( & self ) -> VisitedGuard {
626
- let buffer = self
627
- . locked_buffers
628
- . lock ( )
629
- . pop ( )
630
- . unwrap_or_else ( || VisitedBuffer :: new ( self . n as _ ) ) ;
631
- VisitedGuard { buffer, pool : self }
632
- }
633
-
634
- fn fetch2 ( & self ) -> VisitedGuardChecker {
635
- let mut buffer = self
636
- . locked_buffers
637
- . lock ( )
638
- . pop ( )
639
- . unwrap_or_else ( || VisitedBuffer :: new ( self . n as _ ) ) ;
640
- {
641
- buffer. version = buffer. version . wrapping_add ( 1 ) ;
642
- if buffer. version == 0 {
643
- buffer. data . fill ( 0 ) ;
644
- }
645
- }
646
- VisitedGuardChecker { buffer, pool : self }
647
- }
648
- }
649
-
650
- struct VisitedGuard < ' a > {
651
- buffer : VisitedBuffer ,
652
- pool : & ' a VisitedPool ,
653
- }
654
-
655
- impl < ' a > VisitedGuard < ' a > {
656
- fn fetch ( & mut self ) -> VisitedChecker < ' _ > {
657
- self . buffer . version = self . buffer . version . wrapping_add ( 1 ) ;
658
- if self . buffer . version == 0 {
659
- self . buffer . data . fill ( 0 ) ;
660
- }
661
- VisitedChecker {
662
- buffer : & mut self . buffer ,
663
- }
664
- }
665
- }
666
-
667
- impl < ' a > Drop for VisitedGuard < ' a > {
668
- fn drop ( & mut self ) {
669
- let src = VisitedBuffer {
670
- version : 0 ,
671
- data : Vec :: new ( ) ,
672
- } ;
673
- let buffer = std:: mem:: replace ( & mut self . buffer , src) ;
674
- self . pool . locked_buffers . lock ( ) . push ( buffer) ;
675
- }
676
- }
677
-
678
- struct VisitedChecker < ' a > {
679
- buffer : & ' a mut VisitedBuffer ,
680
- }
681
-
682
- impl < ' a > VisitedChecker < ' a > {
683
- fn check ( & mut self , i : u32 ) -> bool {
684
- self . buffer . data [ i as usize ] != self . buffer . version
685
- }
686
- fn mark ( & mut self , i : u32 ) {
687
- self . buffer . data [ i as usize ] = self . buffer . version ;
688
- }
689
- }
690
-
691
- struct VisitedGuardChecker < ' a > {
692
- buffer : VisitedBuffer ,
693
- pool : & ' a VisitedPool ,
694
- }
695
-
696
- impl < ' a > VisitedGuardChecker < ' a > {
697
- fn check ( & mut self , i : u32 ) -> bool {
698
- self . buffer . data [ i as usize ] != self . buffer . version
699
- }
700
- fn mark ( & mut self , i : u32 ) {
701
- self . buffer . data [ i as usize ] = self . buffer . version ;
702
- }
703
- }
704
-
705
- impl < ' a > Drop for VisitedGuardChecker < ' a > {
706
- fn drop ( & mut self ) {
707
- let src = VisitedBuffer {
708
- version : 0 ,
709
- data : Vec :: new ( ) ,
710
- } ;
711
- let buffer = std:: mem:: replace ( & mut self . buffer , src) ;
712
- self . pool . locked_buffers . lock ( ) . push ( buffer) ;
713
- }
714
- }
715
-
716
- struct VisitedBuffer {
717
- version : usize ,
718
- data : Vec < usize > ,
719
- }
720
-
721
- impl VisitedBuffer {
722
- fn new ( capacity : usize ) -> Self {
723
- Self {
724
- version : 0 ,
725
- data : bytemuck:: zeroed_vec ( capacity) ,
726
- }
727
- }
728
- }
729
-
730
663
pub struct ElementHeap {
731
664
binary_heap : BinaryHeap < Element > ,
732
665
k : usize ,
0 commit comments