Skip to content

Commit 43e3f5e

Browse files
tjgreen42 and cevian authored
[Distance metrics M1] support for L2 distance (#156)
This PR adds support for distance metrics other than cosine, beginning with L2. The primitives for this metric were already present (earlier versions of pgvectorscale used L2 rather than cosine), so this PR only adds the operator-class plumbing. As discussed in the design doc (see references), we follow pgvector's syntax. --------- Signed-off-by: tjgreen42 <tj@timescale.com> Co-authored-by: Matvey Arye <cevian@gmail.com>
1 parent 4f86490 commit 43e3f5e

File tree

8 files changed

+296
-53
lines changed

8 files changed

+296
-53
lines changed

pgvectorscale/Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ pg_test = []
2222
[dependencies]
2323
memoffset = "0.9.0"
2424
pgrx = "=0.12.5"
25-
rkyv = { version="0.7.42", features=["validation"]}
26-
simdeez = {version = "1.0.8"}
27-
rand = { version = "0.8", features = [ "small_rng" ] }
25+
rkyv = { version = "0.7.42", features = ["validation"] }
26+
simdeez = { version = "1.0.8" }
27+
rand = { version = "0.8", features = ["small_rng"] }
2828
pgvectorscale_derive = { path = "pgvectorscale_derive" }
2929
semver = "1.0.22"
3030
once_cell = "1.20.1"

pgvectorscale/src/access_method/build.rs

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
use std::time::Instant;
22

3-
use pgrx::pg_sys::{pgstat_progress_update_param, AsPgCStr};
3+
use pg_sys::{FunctionCall0Coll, InvalidOid};
4+
use pgrx::pg_sys::{index_getprocinfo, pgstat_progress_update_param, AsPgCStr};
45
use pgrx::*;
56

7+
use crate::access_method::distance::DistanceType;
68
use crate::access_method::graph::Graph;
79
use crate::access_method::graph_neighbor_store::GraphNeighborStore;
810
use crate::access_method::options::TSVIndexOptions;
911
use crate::access_method::pg_vector::PgVector;
1012
use crate::access_method::stats::{InsertStats, WriteStats};
1113

14+
use crate::access_method::DISKANN_DISTANCE_TYPE_PROC;
1215
use crate::util::page::PageType;
1316
use crate::util::tape::Tape;
1417
use crate::util::*;
@@ -79,7 +82,18 @@ pub extern "C" fn ambuild(
7982

8083
let dimensions = index_relation.tuple_desc().get(0).unwrap().atttypmod;
8184
assert!(dimensions > 0 && dimensions <= 2000);
82-
let meta_page = unsafe { MetaPage::create(&index_relation, dimensions as _, opt) };
85+
86+
let distance_type = unsafe {
87+
let fmgr_info = index_getprocinfo(indexrel, 1, DISKANN_DISTANCE_TYPE_PROC);
88+
if fmgr_info == std::ptr::null_mut() {
89+
error!("No distance type function found for index");
90+
}
91+
let result = FunctionCall0Coll(fmgr_info, InvalidOid).value() as u16;
92+
DistanceType::from_u16(result)
93+
};
94+
95+
let meta_page =
96+
unsafe { MetaPage::create(&index_relation, dimensions as _, distance_type, opt) };
8397

8498
let ntuples = do_heap_scan(index_info, &heap_relation, &index_relation, meta_page);
8599

@@ -487,15 +501,19 @@ pub unsafe extern "C" fn ambuildphasename(phasenum: i64) -> *mut ffi::c_char {
487501
pub mod tests {
488502
use std::collections::HashSet;
489503

504+
use crate::access_method::distance::DistanceType;
490505
use pgrx::*;
491506

492507
// TODO: add a test that inserts and queries vectors that are all identical.
493508

494509
#[cfg(any(test, feature = "pg_test"))]
495510
pub unsafe fn test_index_creation_and_accuracy_scaffold(
511+
distance_type: DistanceType,
496512
index_options: &str,
497513
name: &str,
498514
) -> spi::Result<()> {
515+
let operator = distance_type.get_operator();
516+
let operator_class = distance_type.get_operator_class();
499517
let table_name = format!("test_data_icaa_{}", name);
500518
Spi::run(&format!(
501519
"CREATE TABLE {table_name} (
@@ -515,7 +533,7 @@ pub mod tests {
515533
GROUP BY
516534
i % 300) g;
517535
518-
CREATE INDEX ON {table_name} USING diskann (embedding) WITH ({index_options});
536+
CREATE INDEX ON {table_name} USING diskann (embedding {operator_class}) WITH ({index_options});
519537
520538
521539
SET enable_seqscan = 0;
@@ -525,7 +543,7 @@ pub mod tests {
525543
FROM
526544
{table_name}
527545
ORDER BY
528-
embedding <=> (
546+
embedding {operator} (
529547
SELECT
530548
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
531549
FROM generate_series(1, 1536));"))?;
@@ -542,7 +560,7 @@ pub mod tests {
542560
SET enable_seqscan = 0;
543561
SET enable_indexscan = 1;
544562
SET diskann.query_search_list_size = 2;
545-
WITH cte as (select * from {table_name} order by embedding <=> $1::vector) SELECT count(*) from cte;
563+
WITH cte as (select * from {table_name} order by embedding {operator} $1::vector) SELECT count(*) from cte;
546564
",
547565
),
548566
vec![(
@@ -576,7 +594,7 @@ pub mod tests {
576594
FROM
577595
{table_name}
578596
ORDER BY
579-
embedding <=> (
597+
embedding {operator} (
580598
SELECT
581599
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
582600
FROM generate_series(1, 1536));
@@ -608,7 +626,7 @@ pub mod tests {
608626
FROM
609627
{table_name}
610628
ORDER BY
611-
embedding <=> $1::vector
629+
embedding {operator} $1::vector
612630
LIMIT 10
613631
)
614632
SELECT array_agg(ctid) from cte;"
@@ -631,7 +649,7 @@ pub mod tests {
631649
FROM
632650
{table_name}
633651
ORDER BY
634-
embedding <=> $1::vector
652+
embedding {operator} $1::vector
635653
LIMIT 10
636654
)
637655
SELECT array_agg(ctid) from cte;"
@@ -655,7 +673,7 @@ pub mod tests {
655673
FROM
656674
{table_name}
657675
ORDER BY
658-
embedding <=> $1::vector
676+
embedding {operator} $1::vector
659677
LIMIT 10
660678
)
661679
SELECT array_agg(ctid) from cte;"
@@ -686,7 +704,7 @@ pub mod tests {
686704
SET enable_seqscan = 0;
687705
SET enable_indexscan = 1;
688706
SET diskann.query_search_list_size = 2;
689-
WITH cte as (select * from {table_name} order by embedding <=> $1::vector) SELECT count(*) from cte;
707+
WITH cte as (select * from {table_name} order by embedding {operator} $1::vector) SELECT count(*) from cte;
690708
",
691709
),
692710
vec![(
@@ -759,10 +777,14 @@ pub mod tests {
759777

760778
#[cfg(any(test, feature = "pg_test"))]
761779
pub unsafe fn test_index_updates(
780+
distance_type: DistanceType,
762781
index_options: &str,
763782
expected_cnt: i64,
764783
name: &str,
765784
) -> spi::Result<()> {
785+
let operator_class = distance_type.get_operator_class();
786+
let operator = distance_type.get_operator();
787+
766788
let table_name = format!("test_data_index_updates_{}", name);
767789
Spi::run(&format!(
768790
"CREATE TABLE {table_name} (
@@ -784,7 +806,7 @@ pub mod tests {
784806
GROUP BY
785807
i % {expected_cnt}) g;
786808
787-
CREATE INDEX ON {table_name} USING diskann (embedding) WITH ({index_options});
809+
CREATE INDEX ON {table_name} USING diskann (embedding {operator_class}) WITH ({index_options});
788810
789811
790812
SET enable_seqscan = 0;
@@ -794,7 +816,7 @@ pub mod tests {
794816
FROM
795817
{table_name}
796818
ORDER BY
797-
embedding <=> (
819+
embedding {operator} (
798820
SELECT
799821
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
800822
FROM generate_series(1, 1536));"))?;
@@ -811,7 +833,7 @@ pub mod tests {
811833
SET enable_seqscan = 0;
812834
SET enable_indexscan = 1;
813835
SET diskann.query_search_list_size = 2;
814-
WITH cte as (select * from {table_name} order by embedding <=> $1::vector) SELECT count(*) from cte;
836+
WITH cte as (select * from {table_name} order by embedding {operator} $1::vector) SELECT count(*) from cte;
815837
",
816838
),
817839
vec![(
@@ -850,7 +872,7 @@ pub mod tests {
850872
SET enable_seqscan = 0;
851873
SET enable_indexscan = 1;
852874
SET diskann.query_search_list_size = 2;
853-
WITH cte as (select * from {table_name} order by embedding <=> $1::vector) SELECT count(*) from cte;
875+
WITH cte as (select * from {table_name} order by embedding {operator} $1::vector) SELECT count(*) from cte;
854876
",
855877
),
856878
vec![(

pgvectorscale/src/access_method/distance.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,54 @@
1+
use pgrx::pg_extern;
2+
13
pub type DistanceFn = fn(&[f32], &[f32]) -> f32;
24

5+
/// Distance metric an index is built with.
///
/// The explicit discriminants are the numeric tags understood by
/// `DistanceType::from_u16` and are what the index meta page stores as a
/// `u16`, so existing variants must never be renumbered; only append new ones.
///
/// The enum is fieldless, so it is `Copy`: casting it with `as` or passing it
/// by value does not consume the caller's binding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceType {
    /// Cosine distance (tag `0`).
    Cosine = 0,
    /// Euclidean (L2) distance (tag `1`).
    L2 = 1,
}
10+
11+
impl DistanceType {
12+
pub fn from_u16(value: u16) -> Self {
13+
match value {
14+
0 => DistanceType::Cosine,
15+
1 => DistanceType::L2,
16+
_ => panic!("Unknown DistanceType number {}", value),
17+
}
18+
}
19+
20+
pub fn get_operator(&self) -> &str {
21+
match self {
22+
DistanceType::Cosine => "<=>",
23+
DistanceType::L2 => "<->",
24+
}
25+
}
26+
27+
pub fn get_operator_class(&self) -> &str {
28+
match self {
29+
DistanceType::Cosine => "vector_cosine_ops",
30+
DistanceType::L2 => "vector_l2_ops",
31+
}
32+
}
33+
34+
pub fn get_distance_function(&self) -> DistanceFn {
35+
match self {
36+
DistanceType::Cosine => distance_cosine,
37+
DistanceType::L2 => distance_l2,
38+
}
39+
}
40+
}
41+
42+
#[pg_extern(immutable, parallel_safe)]
43+
pub fn distance_type_cosine() -> i16 {
44+
DistanceType::Cosine as i16
45+
}
46+
47+
#[pg_extern(immutable, parallel_safe)]
48+
pub fn distance_type_l2() -> i16 {
49+
DistanceType::L2 as i16
50+
}
51+
352
/* we use the avx2 version of x86 functions. This verifies that's kosher */
453
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
554
#[cfg(not(target_feature = "avx2"))]

pgvectorscale/src/access_method/meta_page.rs

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::access_method::options::TSVIndexOptions;
88
use crate::util::page;
99
use crate::util::*;
1010

11-
use super::distance::{self, DistanceFn};
11+
use super::distance::{DistanceFn, DistanceType};
1212
use super::options::{
1313
NUM_DIMENSIONS_DEFAULT_SENTINEL, NUM_NEIGHBORS_DEFAULT_SENTINEL,
1414
SBQ_NUM_BITS_PER_DIMENSION_DEFAULT_SENTINEL,
@@ -93,22 +93,6 @@ pub struct MetaPageHeader {
9393
version: u32,
9494
}
9595

96-
/// TODO: move to distance.rs
97-
enum DistanceType {
98-
Cosine = 0,
99-
L2 = 1,
100-
}
101-
102-
impl DistanceType {
103-
fn from_u16(value: u16) -> Self {
104-
match value {
105-
0 => DistanceType::Cosine,
106-
1 => DistanceType::L2,
107-
_ => panic!("Unknown DistanceType number {}", value),
108-
}
109-
}
110-
}
111-
11296
/// This is metadata about the entire index.
11397
/// Stored as the first page (offset 2) in the index relation.
11498
#[derive(Clone, PartialEq, Archive, Deserialize, Serialize, Readable, Writeable)]
@@ -175,10 +159,7 @@ impl MetaPage {
175159
}
176160

177161
pub fn get_distance_function(&self) -> DistanceFn {
178-
match DistanceType::from_u16(self.distance_type) {
179-
DistanceType::Cosine => distance::distance_cosine,
180-
DistanceType::L2 => distance::distance_l2,
181-
}
162+
DistanceType::from_u16(self.distance_type).get_distance_function()
182163
}
183164

184165
pub fn get_storage_type(&self) -> StorageType {
@@ -234,6 +215,7 @@ impl MetaPage {
234215
pub unsafe fn create(
235216
index: &PgRelation,
236217
num_dimensions: u32,
218+
distance_type: DistanceType,
237219
opt: PgBox<TSVIndexOptions>,
238220
) -> MetaPage {
239221
let version = Version::parse(env!("CARGO_PKG_VERSION")).unwrap();
@@ -272,7 +254,7 @@ impl MetaPage {
272254
magic_number: TSV_MAGIC_NUMBER,
273255
version: TSV_VERSION,
274256
extension_version_when_built: version.to_string(),
275-
distance_type: DistanceType::Cosine as u16,
257+
distance_type: distance_type as u16,
276258
num_dimensions,
277259
num_dimensions_to_index,
278260
storage_type: (*opt).get_storage_type() as u8,

0 commit comments

Comments
 (0)