@@ -202,6 +202,7 @@ impl<'a> Graph<'a> {
202
202
& mut self ,
203
203
storage : & S ,
204
204
neighbors_of : ItemPointer ,
205
+ labels : Option < & LabelSet > ,
205
206
additional_neighbors : Vec < NeighborWithDistance > ,
206
207
stats : & mut PruneNeighborStats ,
207
208
) -> ( bool , Vec < NeighborWithDistance > ) {
@@ -237,7 +238,7 @@ impl<'a> Graph<'a> {
237
238
238
239
let ( pruned, new_neighbors) =
239
240
if candidates. len ( ) > self . neighbor_store . max_neighbors ( self . get_meta_page ( ) ) {
240
- let new_list = self . prune_neighbors ( neighbors_of , candidates, storage, stats) ;
241
+ let new_list = self . prune_neighbors ( labels , candidates, storage, stats) ;
241
242
( true , new_list)
242
243
} else {
243
244
( false , candidates)
@@ -247,6 +248,7 @@ impl<'a> Graph<'a> {
247
248
storage,
248
249
self . meta_page ,
249
250
neighbors_of,
251
+ labels. cloned ( ) ,
250
252
new_neighbors. clone ( ) ,
251
253
stats,
252
254
) ;
@@ -278,6 +280,7 @@ impl<'a> Graph<'a> {
278
280
& self ,
279
281
index_pointer : IndexPointer ,
280
282
query : LabeledVector ,
283
+ no_filter : bool ,
281
284
storage : & S ,
282
285
stats : & mut GreedySearchStats ,
283
286
) -> HashSet < NeighborWithDistance > {
@@ -286,7 +289,13 @@ impl<'a> Graph<'a> {
286
289
//no nodes in the graph
287
290
return HashSet :: with_capacity ( 0 ) ;
288
291
}
289
- let start_nodes = start_nodes. unwrap ( ) . get_for_node ( query. labels ( ) ) ;
292
+
293
+ let start_nodes = if no_filter {
294
+ start_nodes. unwrap ( ) . get_for_node ( None )
295
+ } else {
296
+ start_nodes. unwrap ( ) . get_for_node ( query. labels ( ) )
297
+ } ;
298
+
290
299
let dm = storage. get_query_distance_measure ( query) ;
291
300
let search_list_size = self . meta_page . get_search_list_size_for_build ( ) as usize ;
292
301
let mut l = ListSearchResult :: new (
@@ -299,7 +308,13 @@ impl<'a> Graph<'a> {
299
308
storage,
300
309
) ;
301
310
let mut visited_nodes = HashSet :: with_capacity ( search_list_size) ;
302
- self . greedy_search_iterate ( & mut l, search_list_size, Some ( & mut visited_nodes) , storage) ;
311
+ self . greedy_search_iterate (
312
+ & mut l,
313
+ search_list_size,
314
+ no_filter,
315
+ Some ( & mut visited_nodes) ,
316
+ storage,
317
+ ) ;
303
318
stats. combine ( & l. stats ) ;
304
319
visited_nodes
305
320
}
@@ -336,6 +351,7 @@ impl<'a> Graph<'a> {
336
351
& self ,
337
352
lsr : & mut ListSearchResult < S :: QueryDistanceMeasure , S :: LSNPrivateData > ,
338
353
visit_n_closest : usize ,
354
+ no_filter : bool ,
339
355
mut visited_nodes : Option < & mut HashSet < NeighborWithDistance > > ,
340
356
storage : & S ,
341
357
) {
@@ -352,7 +368,7 @@ impl<'a> Graph<'a> {
352
368
}
353
369
}
354
370
lsr. stats . record_visit ( ) ;
355
- storage. visit_lsn ( lsr, list_search_entry_idx, & self . neighbor_store ) ;
371
+ storage. visit_lsn ( lsr, list_search_entry_idx, & self . neighbor_store , no_filter ) ;
356
372
}
357
373
}
358
374
@@ -363,7 +379,7 @@ impl<'a> Graph<'a> {
363
379
/// if we save the factors or the distances and add incrementally. Not sure.
364
380
pub fn prune_neighbors < S : Storage > (
365
381
& self ,
366
- _neighbors_of : ItemPointer ,
382
+ labels : Option < & LabelSet > ,
367
383
mut candidates : Vec < NeighborWithDistance > ,
368
384
storage : & S ,
369
385
stats : & mut PruneNeighborStats ,
@@ -393,6 +409,7 @@ impl<'a> Graph<'a> {
393
409
if results. len ( ) >= self . get_meta_page ( ) . get_num_neighbors ( ) as _ {
394
410
return results;
395
411
}
412
+
396
413
if max_factors[ i] > alpha {
397
414
continue ;
398
415
}
@@ -412,13 +429,27 @@ impl<'a> Graph<'a> {
412
429
)
413
430
} ;
414
431
432
+ // TODO: optimization: precompute intersection of `labels` and `existing_neighbor.get_labels()`
433
+ // and use this for the `contains_intersection` check inside the loop
434
+
415
435
//go thru the other candidates (tail of the list)
416
436
for ( j, candidate_neighbor) in candidates. iter ( ) . enumerate ( ) . skip ( i + 1 ) {
417
437
//has it been completely excluded?
418
438
if max_factors[ j] > max_alpha {
419
439
continue ;
420
440
}
421
441
442
+ // Does it contain essential labels?
443
+ if let Some ( labels) = labels {
444
+ if !existing_neighbor
445
+ . get_labels ( )
446
+ . unwrap ( )
447
+ . contains_intersection ( candidate_neighbor. get_labels ( ) . unwrap ( ) , labels)
448
+ {
449
+ continue ;
450
+ }
451
+ }
452
+
422
453
let raw_distance_between_candidate_and_existing_neighbor = unsafe {
423
454
dist_state
424
455
. get_distance ( candidate_neighbor. get_index_pointer_to_neighbor ( ) , stats)
@@ -468,6 +499,7 @@ impl<'a> Graph<'a> {
468
499
storage,
469
500
self . meta_page ,
470
501
index_pointer,
502
+ vec. labels ( ) . cloned ( ) ,
471
503
Vec :: < NeighborWithDistance > :: with_capacity (
472
504
self . neighbor_store . max_neighbors ( self . meta_page ) as _ ,
473
505
) ,
@@ -494,12 +526,11 @@ impl<'a> Graph<'a> {
494
526
495
527
/// Check that all nodes of the graph are reachable from the start node(s)
496
528
#[ allow( dead_code) ]
497
- pub fn debug_check_consistency < S : Storage > (
498
- & mut self ,
529
+ pub fn debug_count_reachable_nodes < S : Storage > (
530
+ & self ,
499
531
storage : & S ,
500
532
stats : & mut InsertStats ,
501
- ) -> bool {
502
- let num_nodes = 0 ; // TODO: calculate from base table?
533
+ ) -> usize {
503
534
if let Some ( start_nodes) = self . meta_page . get_start_nodes ( ) {
504
535
let mut visited = HashSet :: new ( ) ;
505
536
let mut to_visit = start_nodes. get_all_nodes ( ) ;
@@ -521,9 +552,9 @@ impl<'a> Graph<'a> {
521
552
to_visit. push ( ip) ;
522
553
}
523
554
}
524
- return visited. len ( ) == num_nodes ;
555
+ return visited. len ( ) ;
525
556
}
526
- num_nodes == 0
557
+ 0
527
558
}
528
559
529
560
fn debug_format_labels ( labels : Option < LabelSet > ) -> String {
@@ -612,24 +643,44 @@ digraph G {
612
643
index : & PgRelation ,
613
644
index_pointer : IndexPointer ,
614
645
vec : LabeledVector ,
646
+ spare_vec : LabeledVector ,
615
647
storage : & S ,
616
648
stats : & mut InsertStats ,
617
649
) {
618
650
self . update_start_nodes ( index, index_pointer, & vec, storage, stats) ;
619
651
652
+ if vec. labels ( ) . is_some ( ) {
653
+ // Insert starting from label start nodes and apply label filtering
654
+ self . insert_internal ( index_pointer, spare_vec, false , storage, stats) ;
655
+ }
656
+
657
+ // Insert starting from default start node and avoid label filtering
658
+ self . insert_internal ( index_pointer, vec, true , storage, stats) ;
659
+ }
660
+
661
+ fn insert_internal < S : Storage > (
662
+ & mut self ,
663
+ index_pointer : IndexPointer ,
664
+ vec : LabeledVector ,
665
+ no_filter : bool ,
666
+ storage : & S ,
667
+ stats : & mut InsertStats ,
668
+ ) {
620
669
let labels = vec. labels ( ) . cloned ( ) ;
621
670
622
671
#[ allow( clippy:: mutable_key_type) ]
623
672
let v = self . greedy_search_for_build (
624
673
index_pointer,
625
674
vec,
675
+ no_filter,
626
676
storage,
627
677
& mut stats. greedy_search_stats ,
628
678
) ;
629
679
630
680
let ( _, neighbor_list) = self . add_neighbors (
631
681
storage,
632
682
index_pointer,
683
+ labels. as_ref ( ) ,
633
684
v. into_iter ( ) . collect ( ) ,
634
685
& mut stats. prune_neighbor_stats ,
635
686
) ;
@@ -641,6 +692,7 @@ digraph G {
641
692
let neighbor_contains_new_point = self . update_back_pointer (
642
693
neighbor. get_index_pointer_to_neighbor ( ) ,
643
694
index_pointer,
695
+ neighbor. get_labels ( ) ,
644
696
labels. as_ref ( ) ,
645
697
neighbor. get_distance_with_tie_break ( ) ,
646
698
storage,
@@ -664,10 +716,12 @@ digraph G {
664
716
}
665
717
}
666
718
719
+ #[ allow( clippy:: too_many_arguments) ]
667
720
fn update_back_pointer < S : Storage > (
668
721
& mut self ,
669
722
from : IndexPointer ,
670
723
to : IndexPointer ,
724
+ from_labels : Option < & LabelSet > ,
671
725
to_labels : Option < & LabelSet > ,
672
726
distance_with_tie_break : & DistanceWithTieBreak ,
673
727
storage : & S ,
@@ -678,7 +732,7 @@ digraph G {
678
732
distance_with_tie_break. clone( ) ,
679
733
to_labels. cloned( ) ,
680
734
) ] ;
681
- let ( _pruned, n) = self . add_neighbors ( storage, from, new. clone ( ) , prune_stats) ;
735
+ let ( _pruned, n) = self . add_neighbors ( storage, from, from_labels , new. clone ( ) , prune_stats) ;
682
736
n. contains ( & new[ 0 ] )
683
737
}
684
738
}
0 commit comments