Skip to content

Commit a3aab4f

Browse files
tjgreen42cevian
andauthored
Fixes for poor recall with label filters (#225)
This PR fixes a pair of issues that were observed in benchmarks to cause poor recall with label filtering in certain cases: - missing check to preserve edges in pruning containing label information not covered by other edges (cf. line 6 of Algorithm 3 from Filtered DiskANN paper) - missing logic to suppress label-based checks when inserting into the default start node (resulting in poor recall when no label filter condition was present) Testing: - new sanity check unit test - validated benchmark runs to be reported in an upcoming blog post --------- Signed-off-by: tjgreen42 <tj@timescale.com> Co-authored-by: Matvey Arye <cevian@gmail.com>
1 parent a6c9721 commit a3aab4f

File tree

10 files changed

+464
-45
lines changed

10 files changed

+464
-45
lines changed

pgvectorscale/src/access_method/build.rs

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ unsafe fn aminsert_internal(
178178
}
179179
let vec = vec.unwrap();
180180

181+
// PgVector is not cloneable, but in some cases we need a second copy of it
182+
// for the insert. This is a bit of a hack to get that second copy.
183+
let spare_vec = LabeledVector::from_datums(values, isnull, &meta_page).unwrap();
184+
181185
let heap_pointer = ItemPointer::with_item_pointer_data(*heap_tid);
182186
let mut storage = meta_page.get_storage_type();
183187
let mut stats = InsertStats::new();
@@ -189,10 +193,12 @@ unsafe fn aminsert_internal(
189193
&heap_relation,
190194
meta_page.get_distance_function(),
191195
);
196+
assert!(vec.labels().is_none());
192197
insert_storage(
193198
&plain,
194199
&index_relation,
195200
vec,
201+
spare_vec,
196202
heap_pointer,
197203
&mut meta_page,
198204
&mut stats,
@@ -209,6 +215,7 @@ unsafe fn aminsert_internal(
209215
&bq,
210216
&index_relation,
211217
vec,
218+
spare_vec,
212219
heap_pointer,
213220
&mut meta_page,
214221
&mut stats,
@@ -222,6 +229,7 @@ unsafe fn insert_storage<S: Storage>(
222229
storage: &S,
223230
index_relation: &PgRelation,
224231
vector: LabeledVector,
232+
spare_vector: LabeledVector,
225233
heap_pointer: ItemPointer,
226234
meta_page: &mut MetaPage,
227235
stats: &mut InsertStats,
@@ -237,7 +245,14 @@ unsafe fn insert_storage<S: Storage>(
237245
);
238246

239247
let mut graph = Graph::new(GraphNeighborStore::Disk, meta_page);
240-
graph.insert(index_relation, index_pointer, vector, storage, stats);
248+
graph.insert(
249+
index_relation,
250+
index_pointer,
251+
vector,
252+
spare_vector,
253+
storage,
254+
stats,
255+
);
241256
}
242257

243258
#[pg_guard]
@@ -350,14 +365,14 @@ fn finalize_index_build<S: Storage>(
350365
) -> usize {
351366
match state.graph.get_neighbor_store() {
352367
GraphNeighborStore::Builder(builder) => {
353-
for (&index_pointer, neighbors) in builder.iter() {
368+
for (&index_pointer, (labels, neighbors)) in builder.iter() {
354369
write_stats.num_nodes += 1;
355370
let prune_neighbors;
356371
let neighbors =
357372
if neighbors.len() > state.graph.get_meta_page().get_num_neighbors() as _ {
358373
//OPT: get rid of this clone
359374
prune_neighbors = state.graph.prune_neighbors(
360-
index_pointer,
375+
labels,
361376
neighbors.clone(),
362377
storage,
363378
&mut write_stats.prune_stats,
@@ -459,13 +474,33 @@ unsafe extern "C" fn build_callback(
459474
StorageBuildState::SbqSpeedup(bq, state) => {
460475
let vec = LabeledVector::from_datums(values, isnull, state.graph.get_meta_page());
461476
if let Some(vec) = vec {
462-
build_callback_memory_wrapper(&index_relation, heap_pointer, vec, state, *bq);
477+
let spare_vec =
478+
LabeledVector::from_datums(values, isnull, state.graph.get_meta_page())
479+
.unwrap();
480+
build_callback_memory_wrapper(
481+
&index_relation,
482+
heap_pointer,
483+
vec,
484+
spare_vec,
485+
state,
486+
*bq,
487+
);
463488
}
464489
}
465490
StorageBuildState::Plain(plain, state) => {
466491
let vec = LabeledVector::from_datums(values, isnull, state.graph.get_meta_page());
467492
if let Some(vec) = vec {
468-
build_callback_memory_wrapper(&index_relation, heap_pointer, vec, state, *plain);
493+
let spare_vec =
494+
LabeledVector::from_datums(values, isnull, state.graph.get_meta_page())
495+
.unwrap();
496+
build_callback_memory_wrapper(
497+
&index_relation,
498+
heap_pointer,
499+
vec,
500+
spare_vec,
501+
state,
502+
*plain,
503+
);
469504
}
470505
}
471506
}
@@ -476,12 +511,13 @@ unsafe fn build_callback_memory_wrapper<S: Storage>(
476511
index: &PgRelation,
477512
heap_pointer: ItemPointer,
478513
vector: LabeledVector,
514+
spare_vector: LabeledVector,
479515
state: &mut BuildState,
480516
storage: &mut S,
481517
) {
482518
let mut old_context = state.memcxt.set_as_current();
483519

484-
build_callback_internal(index, heap_pointer, vector, state, storage);
520+
build_callback_internal(index, heap_pointer, vector, spare_vector, state, storage);
485521

486522
old_context.set_as_current();
487523
state.memcxt.reset();
@@ -492,6 +528,7 @@ fn build_callback_internal<S: Storage>(
492528
index: &PgRelation,
493529
heap_pointer: ItemPointer,
494530
vector: LabeledVector,
531+
spare_vector: LabeledVector,
495532
state: &mut BuildState,
496533
storage: &mut S,
497534
) {
@@ -520,9 +557,14 @@ fn build_callback_internal<S: Storage>(
520557
&mut state.stats,
521558
);
522559

523-
state
524-
.graph
525-
.insert(index, index_pointer, vector, storage, &mut state.stats);
560+
state.graph.insert(
561+
index,
562+
index_pointer,
563+
vector,
564+
spare_vector,
565+
storage,
566+
&mut state.stats,
567+
);
526568
}
527569

528570
const BUILD_PHASE_TRAINING: i64 = 0;

pgvectorscale/src/access_method/graph.rs

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ impl<'a> Graph<'a> {
202202
&mut self,
203203
storage: &S,
204204
neighbors_of: ItemPointer,
205+
labels: Option<&LabelSet>,
205206
additional_neighbors: Vec<NeighborWithDistance>,
206207
stats: &mut PruneNeighborStats,
207208
) -> (bool, Vec<NeighborWithDistance>) {
@@ -237,7 +238,7 @@ impl<'a> Graph<'a> {
237238

238239
let (pruned, new_neighbors) =
239240
if candidates.len() > self.neighbor_store.max_neighbors(self.get_meta_page()) {
240-
let new_list = self.prune_neighbors(neighbors_of, candidates, storage, stats);
241+
let new_list = self.prune_neighbors(labels, candidates, storage, stats);
241242
(true, new_list)
242243
} else {
243244
(false, candidates)
@@ -247,6 +248,7 @@ impl<'a> Graph<'a> {
247248
storage,
248249
self.meta_page,
249250
neighbors_of,
251+
labels.cloned(),
250252
new_neighbors.clone(),
251253
stats,
252254
);
@@ -278,6 +280,7 @@ impl<'a> Graph<'a> {
278280
&self,
279281
index_pointer: IndexPointer,
280282
query: LabeledVector,
283+
no_filter: bool,
281284
storage: &S,
282285
stats: &mut GreedySearchStats,
283286
) -> HashSet<NeighborWithDistance> {
@@ -286,7 +289,13 @@ impl<'a> Graph<'a> {
286289
//no nodes in the graph
287290
return HashSet::with_capacity(0);
288291
}
289-
let start_nodes = start_nodes.unwrap().get_for_node(query.labels());
292+
293+
let start_nodes = if no_filter {
294+
start_nodes.unwrap().get_for_node(None)
295+
} else {
296+
start_nodes.unwrap().get_for_node(query.labels())
297+
};
298+
290299
let dm = storage.get_query_distance_measure(query);
291300
let search_list_size = self.meta_page.get_search_list_size_for_build() as usize;
292301
let mut l = ListSearchResult::new(
@@ -299,7 +308,13 @@ impl<'a> Graph<'a> {
299308
storage,
300309
);
301310
let mut visited_nodes = HashSet::with_capacity(search_list_size);
302-
self.greedy_search_iterate(&mut l, search_list_size, Some(&mut visited_nodes), storage);
311+
self.greedy_search_iterate(
312+
&mut l,
313+
search_list_size,
314+
no_filter,
315+
Some(&mut visited_nodes),
316+
storage,
317+
);
303318
stats.combine(&l.stats);
304319
visited_nodes
305320
}
@@ -336,6 +351,7 @@ impl<'a> Graph<'a> {
336351
&self,
337352
lsr: &mut ListSearchResult<S::QueryDistanceMeasure, S::LSNPrivateData>,
338353
visit_n_closest: usize,
354+
no_filter: bool,
339355
mut visited_nodes: Option<&mut HashSet<NeighborWithDistance>>,
340356
storage: &S,
341357
) {
@@ -352,7 +368,7 @@ impl<'a> Graph<'a> {
352368
}
353369
}
354370
lsr.stats.record_visit();
355-
storage.visit_lsn(lsr, list_search_entry_idx, &self.neighbor_store);
371+
storage.visit_lsn(lsr, list_search_entry_idx, &self.neighbor_store, no_filter);
356372
}
357373
}
358374

@@ -363,7 +379,7 @@ impl<'a> Graph<'a> {
363379
/// if we save the factors or the distances and add incrementally. Not sure.
364380
pub fn prune_neighbors<S: Storage>(
365381
&self,
366-
_neighbors_of: ItemPointer,
382+
labels: Option<&LabelSet>,
367383
mut candidates: Vec<NeighborWithDistance>,
368384
storage: &S,
369385
stats: &mut PruneNeighborStats,
@@ -393,6 +409,7 @@ impl<'a> Graph<'a> {
393409
if results.len() >= self.get_meta_page().get_num_neighbors() as _ {
394410
return results;
395411
}
412+
396413
if max_factors[i] > alpha {
397414
continue;
398415
}
@@ -412,13 +429,27 @@ impl<'a> Graph<'a> {
412429
)
413430
};
414431

432+
// TODO: optimization: precompute intersection of `labels` and `existing_neighbor.get_labels()`
433+
// and use this for the `contains_intersection` check inside the loop
434+
415435
//go thru the other candidates (tail of the list)
416436
for (j, candidate_neighbor) in candidates.iter().enumerate().skip(i + 1) {
417437
//has it been completely excluded?
418438
if max_factors[j] > max_alpha {
419439
continue;
420440
}
421441

442+
// Does it contain essential labels?
443+
if let Some(labels) = labels {
444+
if !existing_neighbor
445+
.get_labels()
446+
.unwrap()
447+
.contains_intersection(candidate_neighbor.get_labels().unwrap(), labels)
448+
{
449+
continue;
450+
}
451+
}
452+
422453
let raw_distance_between_candidate_and_existing_neighbor = unsafe {
423454
dist_state
424455
.get_distance(candidate_neighbor.get_index_pointer_to_neighbor(), stats)
@@ -468,6 +499,7 @@ impl<'a> Graph<'a> {
468499
storage,
469500
self.meta_page,
470501
index_pointer,
502+
vec.labels().cloned(),
471503
Vec::<NeighborWithDistance>::with_capacity(
472504
self.neighbor_store.max_neighbors(self.meta_page) as _,
473505
),
@@ -494,12 +526,11 @@ impl<'a> Graph<'a> {
494526

495527
/// Check that all nodes of the graph are reachable from the start node(s)
496528
#[allow(dead_code)]
497-
pub fn debug_check_consistency<S: Storage>(
498-
&mut self,
529+
pub fn debug_count_reachable_nodes<S: Storage>(
530+
&self,
499531
storage: &S,
500532
stats: &mut InsertStats,
501-
) -> bool {
502-
let num_nodes = 0; // TODO: calculate from base table?
533+
) -> usize {
503534
if let Some(start_nodes) = self.meta_page.get_start_nodes() {
504535
let mut visited = HashSet::new();
505536
let mut to_visit = start_nodes.get_all_nodes();
@@ -521,9 +552,9 @@ impl<'a> Graph<'a> {
521552
to_visit.push(ip);
522553
}
523554
}
524-
return visited.len() == num_nodes;
555+
return visited.len();
525556
}
526-
num_nodes == 0
557+
0
527558
}
528559

529560
fn debug_format_labels(labels: Option<LabelSet>) -> String {
@@ -612,24 +643,44 @@ digraph G {
612643
index: &PgRelation,
613644
index_pointer: IndexPointer,
614645
vec: LabeledVector,
646+
spare_vec: LabeledVector,
615647
storage: &S,
616648
stats: &mut InsertStats,
617649
) {
618650
self.update_start_nodes(index, index_pointer, &vec, storage, stats);
619651

652+
if vec.labels().is_some() {
653+
// Insert starting from label start nodes and apply label filtering
654+
self.insert_internal(index_pointer, spare_vec, false, storage, stats);
655+
}
656+
657+
// Insert starting from default start node and avoid label filtering
658+
self.insert_internal(index_pointer, vec, true, storage, stats);
659+
}
660+
661+
fn insert_internal<S: Storage>(
662+
&mut self,
663+
index_pointer: IndexPointer,
664+
vec: LabeledVector,
665+
no_filter: bool,
666+
storage: &S,
667+
stats: &mut InsertStats,
668+
) {
620669
let labels = vec.labels().cloned();
621670

622671
#[allow(clippy::mutable_key_type)]
623672
let v = self.greedy_search_for_build(
624673
index_pointer,
625674
vec,
675+
no_filter,
626676
storage,
627677
&mut stats.greedy_search_stats,
628678
);
629679

630680
let (_, neighbor_list) = self.add_neighbors(
631681
storage,
632682
index_pointer,
683+
labels.as_ref(),
633684
v.into_iter().collect(),
634685
&mut stats.prune_neighbor_stats,
635686
);
@@ -641,6 +692,7 @@ digraph G {
641692
let neighbor_contains_new_point = self.update_back_pointer(
642693
neighbor.get_index_pointer_to_neighbor(),
643694
index_pointer,
695+
neighbor.get_labels(),
644696
labels.as_ref(),
645697
neighbor.get_distance_with_tie_break(),
646698
storage,
@@ -664,10 +716,12 @@ digraph G {
664716
}
665717
}
666718

719+
#[allow(clippy::too_many_arguments)]
667720
fn update_back_pointer<S: Storage>(
668721
&mut self,
669722
from: IndexPointer,
670723
to: IndexPointer,
724+
from_labels: Option<&LabelSet>,
671725
to_labels: Option<&LabelSet>,
672726
distance_with_tie_break: &DistanceWithTieBreak,
673727
storage: &S,
@@ -678,7 +732,7 @@ digraph G {
678732
distance_with_tie_break.clone(),
679733
to_labels.cloned(),
680734
)];
681-
let (_pruned, n) = self.add_neighbors(storage, from, new.clone(), prune_stats);
735+
let (_pruned, n) = self.add_neighbors(storage, from, from_labels, new.clone(), prune_stats);
682736
n.contains(&new[0])
683737
}
684738
}

0 commit comments

Comments
 (0)