@@ -3,6 +3,7 @@ use bytes::Bytes;
3
3
use futures:: StreamExt ;
4
4
use futures:: TryStreamExt ;
5
5
use hnsw:: Hnsw ;
6
+ use hyper:: HeaderMap ;
6
7
use hyper:: StatusCode ;
7
8
use hyper:: {
8
9
service:: { make_service_fn, service_fn} ,
@@ -28,6 +29,7 @@ use std::{
28
29
} ;
29
30
use thiserror:: Error ;
30
31
use tokio:: sync:: Mutex ;
32
+ use tokio:: task;
31
33
use tokio:: { io:: AsyncBufReadExt , sync:: RwLock } ;
32
34
use tokio_stream:: { wrappers:: LinesStream , Stream } ;
33
35
use tokio_util:: io:: StreamReader ;
@@ -39,6 +41,7 @@ use crate::indexer::search;
39
41
use crate :: indexer:: serialize_index;
40
42
use crate :: indexer:: Point ;
41
43
use crate :: indexer:: PointOperation ;
44
+ use crate :: indexer:: SearchError ;
42
45
use crate :: indexer:: { start_indexing_from_operations, HnswIndex , IndexIdentifier , OpenAI } ;
43
46
use crate :: openai:: embeddings_for;
44
47
use crate :: vectors:: VectorStore ;
@@ -113,6 +116,28 @@ fn query_map(uri: &Uri) -> HashMap<String, String> {
113
116
. unwrap_or_else ( || HashMap :: with_capacity ( 0 ) )
114
117
}
115
118
119
+ #[ derive( Debug , Error ) ]
120
+ enum HeaderError {
121
+ #[ error( "Key was not valid utf8" ) ]
122
+ KeyNotUtf8 ,
123
+ #[ error( "Missing the key {0}" ) ]
124
+ MissingKey ( String ) ,
125
+ }
126
+
127
+ fn get_header_value ( header : & HeaderMap , key : & str ) -> Result < String , HeaderError > {
128
+ let value = header. get ( key) ;
129
+ match value {
130
+ Some ( value) => {
131
+ let value = String :: from_utf8 ( value. as_bytes ( ) . to_vec ( ) ) ;
132
+ match value {
133
+ Ok ( value) => Ok ( value) ,
134
+ Err ( _) => Err ( HeaderError :: KeyNotUtf8 ) ,
135
+ }
136
+ }
137
+ None => Err ( HeaderError :: MissingKey ( key. to_string ( ) ) ) ,
138
+ }
139
+ }
140
+
116
141
fn uri_to_spec ( uri : & Uri ) -> Result < ResourceSpec , SpecParseError > {
117
142
lazy_static ! {
118
143
static ref RE_INDEX : Regex = Regex :: new( r"^/index(/?)$" ) . unwrap( ) ;
@@ -278,6 +303,20 @@ async fn get_operations_from_content_endpoint(
278
303
Ok ( fp)
279
304
}
280
305
306
+ #[ derive( Debug , Error ) ]
307
+ enum ResponseError {
308
+ #[ error( "{0:?}" ) ]
309
+ HeaderError ( #[ from] HeaderError ) ,
310
+ #[ error( "{0:?}" ) ]
311
+ IoError ( #[ from] std:: io:: Error ) ,
312
+ #[ error( "{0:?}" ) ]
313
+ SerdeError ( #[ from] serde_json:: Error ) ,
314
+ #[ error( "{0:?}" ) ]
315
+ StartIndexError ( #[ from] StartIndexError ) ,
316
+ #[ error( "{0:?}" ) ]
317
+ SearchError ( #[ from] SearchError ) ,
318
+ }
319
+
281
320
fn add_to_duplicates ( duplicates : & mut HashMap < usize , usize > , id1 : usize , id2 : usize ) {
282
321
if id1 < id2 {
283
322
duplicates. insert ( id1, id2) ;
@@ -293,15 +332,12 @@ impl Service {
293
332
self . tasks . write ( ) . await . insert ( task_id, status) ;
294
333
}
295
334
296
- async fn get_index ( & self , index_id : & str ) -> Option < Arc < HnswIndex > > {
335
+ async fn get_index ( & self , index_id : & str ) -> io :: Result < Arc < HnswIndex > > {
297
336
if let Some ( hnsw) = self . indexes . read ( ) . await . get ( index_id) {
298
- Some ( hnsw) . cloned ( )
337
+ Ok ( hnsw) . cloned ( )
299
338
} else {
300
339
let mut path = self . path . clone ( ) ;
301
- match deserialize_index ( & mut path, index_id, & self . vector_store ) {
302
- Ok ( res) => Some ( res. into ( ) ) ,
303
- Err ( _) => None ,
304
- }
340
+ Ok ( deserialize_index ( & mut path, index_id, & self . vector_store ) ?. into ( ) )
305
341
}
306
342
}
307
343
@@ -416,24 +452,16 @@ impl Service {
416
452
) -> Result < ( ) , AssignIndexError > {
417
453
let source_name = create_index_name ( & domain, & source_commit) ;
418
454
let target_name = create_index_name ( & domain, & target_commit) ;
419
-
420
- if self . get_index ( & target_name) . await . is_some ( ) {
421
- return Err ( AssignIndexError :: TargetCommitAlreadyHasIndex ) ;
422
- }
423
- if let Some ( index) = self . get_index ( & source_name) . await {
424
- let mut indexes = self . indexes . write ( ) . await ;
425
- indexes. insert ( target_name. clone ( ) , index. clone ( ) ) ;
426
-
427
- std:: mem:: drop ( indexes) ;
428
- tokio:: task:: block_in_place ( move || {
429
- let path = self . path . clone ( ) ;
430
- serialize_index ( path, & target_name, ( * index) . clone ( ) ) . unwrap ( ) ;
431
- } ) ;
432
-
433
- Ok ( ( ) )
434
- } else {
435
- Err ( AssignIndexError :: SourceCommitNotFound )
436
- }
455
+ self . get_index ( & target_name) . await ?;
456
+ let index = self . get_index ( & source_name) . await ?;
457
+ let mut indexes = self . indexes . write ( ) . await ;
458
+ indexes. insert ( target_name. clone ( ) , index. clone ( ) ) ;
459
+ std:: mem:: drop ( indexes) ;
460
+ tokio:: task:: block_in_place ( move || {
461
+ let path = self . path . clone ( ) ;
462
+ serialize_index ( path, & target_name, ( * index) . clone ( ) ) . unwrap ( ) ;
463
+ } ) ;
464
+ Ok ( ( ) )
437
465
}
438
466
439
467
async fn process_operation_chunks (
@@ -475,6 +503,20 @@ impl Service {
475
503
( id, hnsw)
476
504
}
477
505
506
+ async fn get_start_index (
507
+ self : Arc < Self > ,
508
+ req : Request < Body > ,
509
+ domain : String ,
510
+ commit : String ,
511
+ previous : Option < String > ,
512
+ ) -> Result < String , ResponseError > {
513
+ let task_id = Service :: generate_task ( ) ;
514
+ let api_key = get_header_value ( req. headers ( ) , "VECTORLINK_EMBEDDING_API_KEY" ) ?;
515
+ self . set_task_status ( task_id. clone ( ) , TaskStatus :: Pending ( 0.0 ) ) ;
516
+ self . start_indexing ( domain, commit, previous, task_id. clone ( ) , api_key) ?;
517
+ Ok ( task_id)
518
+ }
519
+
478
520
async fn get ( self : Arc < Self > , req : Request < Body > ) -> Result < Response < Body > , Infallible > {
479
521
let uri = req. uri ( ) ;
480
522
match dbg ! ( uri_to_spec( uri) ) {
@@ -483,37 +525,8 @@ impl Service {
483
525
commit,
484
526
previous,
485
527
} ) => {
486
- let task_id = Service :: generate_task ( ) ;
487
- let headers = req. headers ( ) ;
488
- let openai_key = headers. get ( "VECTORLINK_EMBEDDING_API_KEY" ) ;
489
- match openai_key {
490
- Some ( openai_key) => {
491
- let openai_key = String :: from_utf8 ( openai_key. as_bytes ( ) . to_vec ( ) ) . unwrap ( ) ;
492
- self . set_task_status ( task_id. clone ( ) , TaskStatus :: Pending ( 0.0 ) )
493
- . await ;
494
- match self . start_indexing (
495
- domain,
496
- commit,
497
- previous,
498
- task_id. clone ( ) ,
499
- openai_key,
500
- ) {
501
- Ok ( ( ) ) => Ok ( Response :: builder ( ) . body ( task_id. into ( ) ) . unwrap ( ) ) ,
502
- Err ( e) => Ok ( Response :: builder ( )
503
- . status ( 400 )
504
- . body ( e. to_string ( ) . into ( ) )
505
- . unwrap ( ) ) ,
506
- }
507
- }
508
- None => Ok ( Response :: builder ( )
509
- . status ( 400 )
510
- . body (
511
- "No API key supplied in header (VECTORLINK_EMBEDDING_API_KEY)"
512
- . to_string ( )
513
- . into ( ) ,
514
- )
515
- . unwrap ( ) ) ,
516
- }
528
+ let result = self . get_start_index ( req, domain, commit, previous) . await ;
529
+ fun_name ( result)
517
530
}
518
531
Ok ( ResourceSpec :: AssignIndex {
519
532
domain,
@@ -553,27 +566,13 @@ impl Service {
553
566
commit,
554
567
threshold,
555
568
} ) => {
556
- let index_id = create_index_name ( & domain, & commit) ;
557
- // if None, then return 404
558
- let hnsw = self . get_index ( & index_id) . await . unwrap ( ) ;
559
- let mut duplicates: HashMap < usize , usize > = HashMap :: new ( ) ;
560
- let elts = hnsw. layer_len ( 0 ) ;
561
- for i in 0 ..elts {
562
- let current_point = & hnsw. feature ( i) ;
563
- let results = search ( current_point, 2 , & hnsw) . unwrap ( ) ;
564
- for result in results. iter ( ) {
565
- if f32:: from_bits ( result. distance ( ) ) < threshold {
566
- add_to_duplicates ( & mut duplicates, i, result. internal_id ( ) )
567
- }
568
- }
569
+ let result = self
570
+ . get_duplicate_candidates ( domain, commit, threshold)
571
+ . await ;
572
+ match result {
573
+ Ok ( result) => todo ! ( ) ,
574
+ Err ( e) => todo ! ( ) ,
569
575
}
570
- let mut v: Vec < ( & str , & str ) > = duplicates
571
- . into_iter ( )
572
- . map ( |( i, j) | ( hnsw. feature ( i) . id ( ) , hnsw. feature ( j) . id ( ) ) )
573
- . collect ( ) ;
574
- Ok ( Response :: builder ( )
575
- . body ( serde_json:: to_string ( & v) . unwrap ( ) . into ( ) )
576
- . unwrap ( ) )
577
576
}
578
577
Ok ( ResourceSpec :: Similar {
579
578
domain,
@@ -618,6 +617,34 @@ impl Service {
618
617
}
619
618
}
620
619
620
+ async fn get_duplicate_candidates (
621
+ self : Arc < Self > ,
622
+ domain : String ,
623
+ commit : String ,
624
+ threshold : f32 ,
625
+ ) -> Result < String , ResponseError > {
626
+ let index_id = create_index_name ( & domain, & commit) ;
627
+ // if None, then return 404
628
+ let hnsw = self . get_index ( & index_id) . await ?;
629
+ let mut duplicates: HashMap < usize , usize > = HashMap :: new ( ) ;
630
+ let elts = hnsw. layer_len ( 0 ) ;
631
+ for i in 0 ..elts {
632
+ let current_point = & hnsw. feature ( i) ;
633
+ let results = search ( current_point, 2 , & hnsw) ?;
634
+ for result in results. iter ( ) {
635
+ if f32:: from_bits ( result. distance ( ) ) < threshold {
636
+ add_to_duplicates ( & mut duplicates, i, result. internal_id ( ) )
637
+ }
638
+ }
639
+ }
640
+ let mut v: Vec < ( & str , & str ) > = duplicates
641
+ . into_iter ( )
642
+ . map ( |( i, j) | ( hnsw. feature ( i) . id ( ) , hnsw. feature ( j) . id ( ) ) )
643
+ . collect ( ) ;
644
+ let result = serde_json:: to_string ( & v) ?;
645
+ Ok ( result)
646
+ }
647
+
621
648
async fn post ( & self , req : Request < Body > ) -> Result < Response < Body > , Infallible > {
622
649
let uri = req. uri ( ) ;
623
650
match uri_to_spec ( uri) {
@@ -626,28 +653,63 @@ impl Service {
626
653
commit,
627
654
count,
628
655
} ) => {
629
- let body_bytes = hyper:: body:: to_bytes ( req. into_body ( ) ) . await . unwrap ( ) ;
656
+ let headers = req. headers ( ) . clone ( ) ;
657
+ let body = req. into_body ( ) ;
658
+ let body_bytes = hyper:: body:: to_bytes ( body) . await . unwrap ( ) ;
630
659
let q = String :: from_utf8 ( body_bytes. to_vec ( ) ) . unwrap ( ) ;
631
- let vec = Box :: new ( ( embeddings_for ( & self . api_key , & [ q] ) . await . unwrap ( ) ) [ 0 ] ) ;
632
- let qp = Point :: Mem { vec } ;
633
- let index_id = create_index_name ( & domain, & commit) ;
634
- // if None, then return 404
635
- let hnsw = self . get_index ( & index_id) . await . unwrap ( ) ;
636
- let res = search ( & qp, count, & hnsw) . unwrap ( ) ;
637
- let ids: Vec < QueryResult > = res
638
- . iter ( )
639
- . map ( |p| QueryResult {
640
- id : p. id ( ) . to_string ( ) ,
641
- distance : f32:: from_bits ( p. distance ( ) ) ,
642
- } )
643
- . collect ( ) ;
644
- let s = serde_json:: to_string ( & ids) . unwrap ( ) ;
645
- Ok ( Response :: builder ( ) . body ( s. into ( ) ) . unwrap ( ) )
660
+ let api_key = get_header_value ( & headers, "VECTORLINK_EMBEDDING_API_KEY" ) ;
661
+ let result = self . index_response ( api_key, q, domain, commit, count) . await ;
662
+ match result {
663
+ Ok ( body) => Ok ( body) ,
664
+ Err ( e) => Ok ( Response :: builder ( )
665
+ . status ( StatusCode :: NOT_FOUND )
666
+ . body ( e. to_string ( ) . into ( ) )
667
+ . unwrap ( ) ) ,
668
+ }
646
669
}
647
670
Ok ( _) => todo ! ( ) ,
648
- Err ( _) => todo ! ( ) ,
671
+ Err ( e) => Ok ( Response :: builder ( )
672
+ . status ( StatusCode :: NOT_FOUND )
673
+ . body ( e. to_string ( ) . into ( ) )
674
+ . unwrap ( ) ) ,
649
675
}
650
676
}
677
+
678
+ async fn index_response (
679
+ & self ,
680
+ api_key : Result < String , HeaderError > ,
681
+ q : String ,
682
+ domain : String ,
683
+ commit : String ,
684
+ count : usize ,
685
+ ) -> Result < Response < Body > , ResponseError > {
686
+ let api_key = api_key?;
687
+ let vec = Box :: new ( ( embeddings_for ( & self . api_key , & [ q] ) . await . unwrap ( ) ) [ 0 ] ) ;
688
+ let qp = Point :: Mem { vec } ;
689
+ let index_id = create_index_name ( & domain, & commit) ;
690
+ // if None, then return 404
691
+ let hnsw = self . get_index ( & index_id) . await ?;
692
+ let res = search ( & qp, count, & hnsw) . unwrap ( ) ;
693
+ let ids: Vec < QueryResult > = res
694
+ . iter ( )
695
+ . map ( |p| QueryResult {
696
+ id : p. id ( ) . to_string ( ) ,
697
+ distance : f32:: from_bits ( p. distance ( ) ) ,
698
+ } )
699
+ . collect ( ) ;
700
+ let s = serde_json:: to_string ( & ids) ?;
701
+ Ok ( Response :: builder ( ) . body ( s. into ( ) ) . unwrap ( ) )
702
+ }
703
+ }
704
+
705
+ fn fun_name ( result : Result < String , ResponseError > ) -> Result < Response < Body > , Infallible > {
706
+ match result {
707
+ Ok ( task_id) => Ok ( Response :: builder ( ) . body ( task_id. into ( ) ) . unwrap ( ) ) ,
708
+ Err ( e) => Ok ( Response :: builder ( )
709
+ . status ( 400 )
710
+ . body ( e. to_string ( ) . into ( ) )
711
+ . unwrap ( ) ) ,
712
+ }
651
713
}
652
714
653
715
#[ derive( Debug , Error ) ]
0 commit comments