1
1
use std:: time:: Instant ;
2
2
3
- use pgrx:: pg_sys:: { pgstat_progress_update_param, AsPgCStr } ;
3
+ use pg_sys:: { FunctionCall0Coll , InvalidOid } ;
4
+ use pgrx:: pg_sys:: { index_getprocinfo, pgstat_progress_update_param, AsPgCStr } ;
4
5
use pgrx:: * ;
5
6
7
+ use crate :: access_method:: distance:: DistanceType ;
6
8
use crate :: access_method:: graph:: Graph ;
7
9
use crate :: access_method:: graph_neighbor_store:: GraphNeighborStore ;
8
10
use crate :: access_method:: options:: TSVIndexOptions ;
9
11
use crate :: access_method:: pg_vector:: PgVector ;
10
12
use crate :: access_method:: stats:: { InsertStats , WriteStats } ;
11
13
14
+ use crate :: access_method:: DISKANN_DISTANCE_TYPE_PROC ;
12
15
use crate :: util:: page:: PageType ;
13
16
use crate :: util:: tape:: Tape ;
14
17
use crate :: util:: * ;
@@ -79,7 +82,18 @@ pub extern "C" fn ambuild(
79
82
80
83
let dimensions = index_relation. tuple_desc ( ) . get ( 0 ) . unwrap ( ) . atttypmod ;
81
84
assert ! ( dimensions > 0 && dimensions <= 2000 ) ;
82
- let meta_page = unsafe { MetaPage :: create ( & index_relation, dimensions as _ , opt) } ;
85
+
86
+ let distance_type = unsafe {
87
+ let fmgr_info = index_getprocinfo ( indexrel, 1 , DISKANN_DISTANCE_TYPE_PROC ) ;
88
+ if fmgr_info == std:: ptr:: null_mut ( ) {
89
+ error ! ( "No distance type function found for index" ) ;
90
+ }
91
+ let result = FunctionCall0Coll ( fmgr_info, InvalidOid ) . value ( ) as u16 ;
92
+ DistanceType :: from_u16 ( result)
93
+ } ;
94
+
95
+ let meta_page =
96
+ unsafe { MetaPage :: create ( & index_relation, dimensions as _ , distance_type, opt) } ;
83
97
84
98
let ntuples = do_heap_scan ( index_info, & heap_relation, & index_relation, meta_page) ;
85
99
@@ -487,15 +501,19 @@ pub unsafe extern "C" fn ambuildphasename(phasenum: i64) -> *mut ffi::c_char {
487
501
pub mod tests {
488
502
use std:: collections:: HashSet ;
489
503
504
+ use crate :: access_method:: distance:: DistanceType ;
490
505
use pgrx:: * ;
491
506
492
507
//TODO: add a test that inserts and queries vectors that are all the same.
493
508
494
509
#[ cfg( any( test, feature = "pg_test" ) ) ]
495
510
pub unsafe fn test_index_creation_and_accuracy_scaffold (
511
+ distance_type : DistanceType ,
496
512
index_options : & str ,
497
513
name : & str ,
498
514
) -> spi:: Result < ( ) > {
515
+ let operator = distance_type. get_operator ( ) ;
516
+ let operator_class = distance_type. get_operator_class ( ) ;
499
517
let table_name = format ! ( "test_data_icaa_{}" , name) ;
500
518
Spi :: run ( & format ! (
501
519
"CREATE TABLE {table_name} (
@@ -515,7 +533,7 @@ pub mod tests {
515
533
GROUP BY
516
534
i % 300) g;
517
535
518
- CREATE INDEX ON {table_name} USING diskann (embedding) WITH ({index_options});
536
+ CREATE INDEX ON {table_name} USING diskann (embedding {operator_class} ) WITH ({index_options});
519
537
520
538
521
539
SET enable_seqscan = 0;
@@ -525,7 +543,7 @@ pub mod tests {
525
543
FROM
526
544
{table_name}
527
545
ORDER BY
528
- embedding <=> (
546
+ embedding {operator} (
529
547
SELECT
530
548
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
531
549
FROM generate_series(1, 1536));" ) ) ?;
@@ -542,7 +560,7 @@ pub mod tests {
542
560
SET enable_seqscan = 0;
543
561
SET enable_indexscan = 1;
544
562
SET diskann.query_search_list_size = 2;
545
- WITH cte as (select * from {table_name} order by embedding <=> $1::vector) SELECT count(*) from cte;
563
+ WITH cte as (select * from {table_name} order by embedding {operator} $1::vector) SELECT count(*) from cte;
546
564
" ,
547
565
) ,
548
566
vec ! [ (
@@ -576,7 +594,7 @@ pub mod tests {
576
594
FROM
577
595
{table_name}
578
596
ORDER BY
579
- embedding <=> (
597
+ embedding {operator} (
580
598
SELECT
581
599
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
582
600
FROM generate_series(1, 1536));
@@ -608,7 +626,7 @@ pub mod tests {
608
626
FROM
609
627
{table_name}
610
628
ORDER BY
611
- embedding <=> $1::vector
629
+ embedding {operator} $1::vector
612
630
LIMIT 10
613
631
)
614
632
SELECT array_agg(ctid) from cte;"
@@ -631,7 +649,7 @@ pub mod tests {
631
649
FROM
632
650
{table_name}
633
651
ORDER BY
634
- embedding <=> $1::vector
652
+ embedding {operator} $1::vector
635
653
LIMIT 10
636
654
)
637
655
SELECT array_agg(ctid) from cte;"
@@ -655,7 +673,7 @@ pub mod tests {
655
673
FROM
656
674
{table_name}
657
675
ORDER BY
658
- embedding <=> $1::vector
676
+ embedding {operator} $1::vector
659
677
LIMIT 10
660
678
)
661
679
SELECT array_agg(ctid) from cte;"
@@ -686,7 +704,7 @@ pub mod tests {
686
704
SET enable_seqscan = 0;
687
705
SET enable_indexscan = 1;
688
706
SET diskann.query_search_list_size = 2;
689
- WITH cte as (select * from {table_name} order by embedding <=> $1::vector) SELECT count(*) from cte;
707
+ WITH cte as (select * from {table_name} order by embedding {operator} $1::vector) SELECT count(*) from cte;
690
708
" ,
691
709
) ,
692
710
vec ! [ (
@@ -759,10 +777,14 @@ pub mod tests {
759
777
760
778
#[ cfg( any( test, feature = "pg_test" ) ) ]
761
779
pub unsafe fn test_index_updates (
780
+ distance_type : DistanceType ,
762
781
index_options : & str ,
763
782
expected_cnt : i64 ,
764
783
name : & str ,
765
784
) -> spi:: Result < ( ) > {
785
+ let operator_class = distance_type. get_operator_class ( ) ;
786
+ let operator = distance_type. get_operator ( ) ;
787
+
766
788
let table_name = format ! ( "test_data_index_updates_{}" , name) ;
767
789
Spi :: run ( & format ! (
768
790
"CREATE TABLE {table_name} (
@@ -784,7 +806,7 @@ pub mod tests {
784
806
GROUP BY
785
807
i % {expected_cnt}) g;
786
808
787
- CREATE INDEX ON {table_name} USING diskann (embedding) WITH ({index_options});
809
+ CREATE INDEX ON {table_name} USING diskann (embedding {operator_class} ) WITH ({index_options});
788
810
789
811
790
812
SET enable_seqscan = 0;
@@ -794,7 +816,7 @@ pub mod tests {
794
816
FROM
795
817
{table_name}
796
818
ORDER BY
797
- embedding <=> (
819
+ embedding {operator} (
798
820
SELECT
799
821
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
800
822
FROM generate_series(1, 1536));" ) ) ?;
@@ -811,7 +833,7 @@ pub mod tests {
811
833
SET enable_seqscan = 0;
812
834
SET enable_indexscan = 1;
813
835
SET diskann.query_search_list_size = 2;
814
- WITH cte as (select * from {table_name} order by embedding <=> $1::vector) SELECT count(*) from cte;
836
+ WITH cte as (select * from {table_name} order by embedding {operator} $1::vector) SELECT count(*) from cte;
815
837
" ,
816
838
) ,
817
839
vec ! [ (
@@ -850,7 +872,7 @@ pub mod tests {
850
872
SET enable_seqscan = 0;
851
873
SET enable_indexscan = 1;
852
874
SET diskann.query_search_list_size = 2;
853
- WITH cte as (select * from {table_name} order by embedding <=> $1::vector) SELECT count(*) from cte;
875
+ WITH cte as (select * from {table_name} order by embedding {operator} $1::vector) SELECT count(*) from cte;
854
876
" ,
855
877
) ,
856
878
vec ! [ (
0 commit comments