Skip to content

Commit 0c57ad4

Browse files
cutecutecat authored and usamoi committed
feat: new text embedding for sparse vector
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
1 parent 13e4245 commit 0c57ad4

File tree

4 files changed

+147
-43
lines changed

4 files changed

+147
-43
lines changed

src/datatype/text_svecf32.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ use std::ffi::{CStr, CString};
1010

1111
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
1212
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output {
13-
use crate::utils::parse::parse_vector;
13+
use crate::utils::parse::parse_pgvector_svector;
1414
let reserve = Typmod::parse_from_i32(typmod)
1515
.unwrap()
1616
.dims()
1717
.map(|x| x.get())
1818
.unwrap_or(0);
19-
let v = parse_vector(input.to_bytes(), reserve as usize, |s| {
19+
let v = parse_pgvector_svector(input.to_bytes(), reserve as usize, |s| {
2020
s.parse::<F32>().ok()
2121
});
2222
match v {
@@ -40,16 +40,20 @@ fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output {
4040

4141
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
4242
fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString {
43+
let dims = vector.for_borrow().dims();
4344
let mut buffer = String::new();
44-
buffer.push('[');
45-
let vec = vector.for_borrow().to_vec();
46-
let mut iter = vec.iter();
47-
if let Some(x) = iter.next() {
48-
buffer.push_str(format!("{}", x).as_str());
49-
}
50-
for x in iter {
51-
buffer.push_str(format!(", {}", x).as_str());
45+
buffer.push('{');
46+
let svec = vector.for_borrow();
47+
let mut need_splitter = true;
48+
for (&index, &value) in svec.indexes().iter().zip(svec.values().iter()) {
49+
match need_splitter {
50+
true => {
51+
buffer.push_str(format!("{}:{}", index + 1, value).as_str());
52+
need_splitter = false;
53+
}
54+
false => buffer.push_str(format!(", {}:{}", index + 1, value).as_str()),
55+
}
5256
}
53-
buffer.push(']');
57+
buffer.push_str(format!("}}/{}", dims).as_str());
5458
CString::new(buffer).unwrap()
5559
}

src/utils/parse.rs

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use num_traits::Zero;
12
use thiserror::Error;
23

34
#[derive(Debug, Error)]
@@ -83,3 +84,105 @@ where
8384
}
8485
Ok(vector)
8586
}
87+
88+
#[inline(always)]
89+
pub fn parse_pgvector_svector<T: Zero + Clone, F>(
90+
input: &[u8],
91+
reserve: usize,
92+
f: F,
93+
) -> Result<Vec<T>, ParseVectorError>
94+
where
95+
F: Fn(&str) -> Option<T>,
96+
{
97+
use arrayvec::ArrayVec;
98+
if input.is_empty() {
99+
return Err(ParseVectorError::EmptyString {});
100+
}
101+
let left = 'a: {
102+
for position in 0..input.len() - 1 {
103+
match input[position] {
104+
b'{' => break 'a position,
105+
b' ' => continue,
106+
_ => return Err(ParseVectorError::BadCharacter { position }),
107+
}
108+
}
109+
return Err(ParseVectorError::BadParentheses { character: '{' });
110+
};
111+
let mut token: ArrayVec<u8, 48> = ArrayVec::new();
112+
let mut capacity = reserve;
113+
let right = 'a: {
114+
for position in (1..input.len()).rev() {
115+
match input[position] {
116+
b'0'..=b'9' => {
117+
if token.try_push(input[position]).is_err() {
118+
return Err(ParseVectorError::TooLongNumber { position });
119+
}
120+
}
121+
b'/' => {
122+
token.reverse();
123+
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
124+
capacity = s
125+
.parse::<usize>()
126+
.map_err(|_| ParseVectorError::BadParsing { position })?;
127+
}
128+
b'}' => {
129+
token.clear();
130+
break 'a position;
131+
}
132+
b' ' => continue,
133+
_ => return Err(ParseVectorError::BadCharacter { position }),
134+
}
135+
}
136+
return Err(ParseVectorError::BadParentheses { character: '}' });
137+
};
138+
let mut vector = vec![T::zero(); capacity];
139+
let mut index: usize = 0;
140+
for position in left + 1..right {
141+
let c = input[position];
142+
match c {
143+
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => {
144+
if token.is_empty() {
145+
token.push(b'$');
146+
}
147+
if token.try_push(c).is_err() {
148+
return Err(ParseVectorError::TooLongNumber { position });
149+
}
150+
}
151+
b',' => {
152+
if !token.is_empty() {
153+
// Safety: all bytes in `token` are ascii characters
154+
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
155+
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
156+
vector[index] = num;
157+
token.clear();
158+
} else {
159+
return Err(ParseVectorError::TooShortNumber { position });
160+
}
161+
}
162+
b':' => {
163+
if !token.is_empty() {
164+
// Safety: all bytes in `token` are ascii characters
165+
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
166+
index = s
167+
.parse::<usize>()
168+
.map_err(|_| ParseVectorError::BadParsing { position })?
169+
- 1;
170+
token.clear();
171+
} else {
172+
return Err(ParseVectorError::TooShortNumber { position });
173+
}
174+
}
175+
b' ' => (),
176+
_ => return Err(ParseVectorError::BadCharacter { position }),
177+
}
178+
}
179+
if !token.is_empty() {
180+
let position = right;
181+
// Safety: all bytes in `token` are ascii characters
182+
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
183+
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
184+
vector[index] = num;
185+
token.clear();
186+
}
187+
Ok(vector)
188+
}

tests/sqllogictest/sparse.slt

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ CREATE INDEX ON t USING vectors (val svector_cos_ops)
2020
WITH (options = "[indexing.hnsw]");
2121

2222
query I
23-
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
23+
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '{1:3,2:1}/6'::svector limit 10) t2;
2424
----
2525
10
2626

2727
query I
28-
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
28+
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '{1:3,2:1}/6'::svector limit 10) t2;
2929
----
3030
10
3131

3232
query I
33-
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
33+
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '{1:3,2:1}/6'::svector limit 10) t2;
3434
----
3535
10
3636

@@ -40,7 +40,7 @@ DROP TABLE t;
4040
query I
4141
SELECT to_svector(5, '{1,2}', '{1,2}');
4242
----
43-
[0, 1, 2, 0, 0]
43+
{2:1, 3:2}/5
4444

4545
query I
4646
SELECT to_svector(5, '{1,2}', '{1,1}') * to_svector(5, '{1,3}', '{2,2}');
@@ -53,8 +53,5 @@ SELECT to_svector(5, '{1,2,3}', '{1,2}');
5353
statement error Duplicated index.
5454
SELECT to_svector(5, '{1,1}', '{1,2}');
5555

56-
statement ok
57-
SELECT replace(replace(array_agg(RANDOM())::real[]::text, '{', '['), '}', ']')::svector FROM generate_series(1, 100000);
58-
5956
statement ok
6057
SELECT to_svector(200000, array_agg(val)::integer[], array_agg(val)::real[]) FROM generate_series(1, 100000) AS VAL;

tests/sqllogictest/svector_subscript.slt

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,87 +2,87 @@ statement ok
22
SET search_path TO pg_temp, vectors;
33

44
query I
5-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:6];
5+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:6];
66
----
7-
[3, 4, 5]
7+
{1:3, 2:4, 3:5}/3
88

99
query I
10-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:4];
10+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:4];
1111
----
12-
[0, 1, 2, 3]
12+
{2:1, 3:2, 4:3}/4
1313

1414
query I
15-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:];
15+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:];
1616
----
17-
[5, 6, 7]
17+
{1:5, 2:6, 3:7}/3
1818

1919
query I
20-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:8];
20+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:8];
2121
----
22-
[1, 2, 3, 4, 5, 6, 7]
22+
{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/7
2323

2424
statement error type svector does only support one subscript
25-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:3][1:1];
25+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:3][1:1];
2626

2727
statement error type svector does only support slice fetch
28-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3];
28+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3];
2929

3030
query I
31-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:4];
31+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:4];
3232
----
3333
NULL
3434

3535
query I
36-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[9:];
36+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[9:];
3737
----
3838
NULL
3939

4040
query I
41-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:0];
41+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:0];
4242
----
4343
NULL
4444

4545
query I
46-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:-1];
46+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:-1];
4747
----
4848
NULL
4949

5050
query I
51-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:NULL];
51+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:NULL];
5252
----
5353
NULL
5454

5555
query I
56-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:8];
56+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:8];
5757
----
5858
NULL
5959

6060
query I
61-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:NULL];
61+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:NULL];
6262
----
6363
NULL
6464

6565
query I
66-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:];
66+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:];
6767
----
6868
NULL
6969

7070
query I
71-
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:NULL];
71+
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:NULL];
7272
----
7373
NULL
7474

7575
query I
76-
SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[3:7];
76+
SELECT ('{3:2, 5:4, 8:7}/8'::svector)[3:7];
7777
----
78-
[0, 4, 0, 0]
78+
{2:4}/4
7979

8080
query I
81-
SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[5:7];
81+
SELECT ('{3:2, 5:4, 8:7}/8'::svector)[5:7];
8282
----
83-
[0, 0]
83+
{}/2
8484

8585
query I
86-
SELECT ('[0, 0, 0, 0, 0, 0, 0, 0]'::svector)[5:7];
86+
SELECT ('{}/8'::svector)[5:7];
8787
----
88-
[0, 0]
88+
{}/2

0 commit comments

Comments
 (0)