Skip to content

Commit 1a9e0b6

Browse files
cutecutecat and usamoi authored
feat: new text representation for sparse vector (#466)
* feat: new text embedding for sparse vector Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix: use 0-based index Signed-off-by: usamoi <usamoi@outlook.com> * refactor: use sparse struct to parse Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix: zero-check, sort and tests Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix: new reject case Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * refactor: use state machine Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix: fsm with more checks Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix: by comments Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix: by comments Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> * fix: remove funcs Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> --------- Signed-off-by: cutecutecat <junyuchen@tensorchord.ai> Signed-off-by: usamoi <usamoi@outlook.com> Co-authored-by: usamoi <usamoi@outlook.com>
1 parent d192c40 commit 1a9e0b6

File tree

5 files changed

+325
-93
lines changed

5 files changed

+325
-93
lines changed

src/datatype/text_svecf32.rs

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,97 @@
1+
use num_traits::Zero;
2+
use pgrx::error;
3+
14
use super::memory_svecf32::SVecf32Output;
25
use crate::datatype::memory_svecf32::SVecf32Input;
3-
use crate::datatype::typmod::Typmod;
46
use crate::error::*;
57
use base::scalar::*;
68
use base::vector::*;
7-
use num_traits::Zero;
89
use pgrx::pg_sys::Oid;
910
use std::ffi::{CStr, CString};
11+
use std::fmt::Write;
1012

1113
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
12-
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output {
13-
use crate::utils::parse::parse_vector;
14-
let reserve = Typmod::parse_from_i32(typmod)
15-
.unwrap()
16-
.dims()
17-
.map(|x| x.get())
18-
.unwrap_or(0);
19-
let v = parse_vector(input.to_bytes(), reserve as usize, |s| {
20-
s.parse::<F32>().ok()
21-
});
14+
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output {
15+
use crate::utils::parse::parse_pgvector_svector;
16+
let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::<F32>().ok());
2217
match v {
2318
Err(e) => {
2419
bad_literal(&e.to_string());
2520
}
26-
Ok(vector) => {
27-
check_value_dims_1048575(vector.len());
28-
let mut indexes = Vec::<u32>::new();
29-
let mut values = Vec::<F32>::new();
30-
for (i, &x) in vector.iter().enumerate() {
31-
if !x.is_zero() {
32-
indexes.push(i as u32);
33-
values.push(x);
21+
Ok((mut indexes, mut values, dims)) => {
22+
check_value_dims_1048575(dims);
23+
// is_sorted
24+
if !indexes.windows(2).all(|i| i[0] <= i[1]) {
25+
assert_eq!(indexes.len(), values.len());
26+
let n = indexes.len();
27+
let mut permutation = (0..n).collect::<Vec<_>>();
28+
permutation.sort_unstable_by_key(|&i| &indexes[i]);
29+
for i in 0..n {
30+
if i == permutation[i] || usize::MAX == permutation[i] {
31+
continue;
32+
}
33+
let index = indexes[i];
34+
let value = values[i];
35+
let mut j = i;
36+
while i != permutation[j] {
37+
let next = permutation[j];
38+
indexes[j] = indexes[permutation[j]];
39+
values[j] = values[permutation[j]];
40+
permutation[j] = usize::MAX;
41+
j = next;
42+
}
43+
indexes[j] = index;
44+
values[j] = value;
45+
permutation[j] = usize::MAX;
46+
}
47+
}
48+
let mut last: Option<u32> = None;
49+
for index in indexes.clone() {
50+
if last == Some(index) {
51+
error!(
52+
"Indexes need to be unique, but there are more than one same index {index}"
53+
)
54+
}
55+
if last >= Some(dims as u32) {
56+
error!("Index out of bounds: the dim is {dims} but the index is {index}");
57+
}
58+
last = Some(index);
59+
{
60+
let mut i = 0;
61+
let mut j = 0;
62+
while j < values.len() {
63+
if !values[j].is_zero() {
64+
indexes[i] = indexes[j];
65+
values[i] = values[j];
66+
i += 1;
67+
}
68+
j += 1;
69+
}
70+
indexes.truncate(i);
71+
values.truncate(i);
3472
}
3573
}
36-
SVecf32Output::new(SVecf32Borrowed::new(vector.len() as u32, &indexes, &values))
74+
SVecf32Output::new(SVecf32Borrowed::new(dims as u32, &indexes, &values))
3775
}
3876
}
3977
}
4078

4179
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
4280
fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString {
81+
let dims = vector.for_borrow().dims();
4382
let mut buffer = String::new();
44-
buffer.push('[');
45-
let vec = vector.for_borrow().to_vec();
46-
let mut iter = vec.iter();
47-
if let Some(x) = iter.next() {
48-
buffer.push_str(format!("{}", x).as_str());
49-
}
50-
for x in iter {
51-
buffer.push_str(format!(", {}", x).as_str());
83+
buffer.push('{');
84+
let svec = vector.for_borrow();
85+
let mut need_splitter = false;
86+
for (&index, &value) in svec.indexes().iter().zip(svec.values().iter()) {
87+
match need_splitter {
88+
false => {
89+
write!(buffer, "{}:{}", index, value).unwrap();
90+
need_splitter = true;
91+
}
92+
true => write!(buffer, ", {}:{}", index, value).unwrap(),
93+
}
5294
}
53-
buffer.push(']');
95+
write!(buffer, "}}/{}", dims).unwrap();
5496
CString::new(buffer).unwrap()
5597
}

src/utils/parse.rs

Lines changed: 194 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
use num_traits::Zero;
12
use thiserror::Error;
23

3-
#[derive(Debug, Error)]
4+
#[derive(Debug, Error, PartialEq)]
45
pub enum ParseVectorError {
56
#[error("The input string is empty.")]
67
EmptyString {},
@@ -83,3 +84,195 @@ where
8384
}
8485
Ok(vector)
8586
}
87+
88+
/// States of the sparse-vector literal parser (`{index:value, ...}/dims`).
///
/// Each state names the token the parser consumed most recently.
#[derive(Clone, Debug, PartialEq)]
enum ParseState {
    /// Initial state; the opening `{` has not been read yet.
    Start,
    /// The opening `{` has been read.
    LeftBracket,
    /// At least one character of an index has been read.
    Index,
    /// The `:` separating an index from its value has been read.
    Colon,
    /// At least one character of a value has been read.
    Value,
    /// The `,` that ends an entry has been read.
    Comma,
    /// The closing `}` has been read.
    RightBracket,
    /// The `/` that introduces the dimension count has been read.
    Splitter,
    /// At least one digit of the dimension count has been read.
    Dims,
}
100+
101+
#[inline(always)]
102+
pub fn parse_pgvector_svector<T: Zero + Clone, F>(
103+
input: &[u8],
104+
f: F,
105+
) -> Result<(Vec<u32>, Vec<T>, usize), ParseVectorError>
106+
where
107+
F: Fn(&str) -> Option<T>,
108+
{
109+
use arrayvec::ArrayVec;
110+
if input.is_empty() {
111+
return Err(ParseVectorError::EmptyString {});
112+
}
113+
let mut token: ArrayVec<u8, 48> = ArrayVec::new();
114+
let mut indexes = Vec::<u32>::new();
115+
let mut values = Vec::<T>::new();
116+
117+
let mut state = ParseState::Start;
118+
for (position, c) in input.iter().copied().enumerate() {
119+
state = match (&state, c) {
120+
(_, b' ') => state,
121+
(ParseState::Start, b'{') => ParseState::LeftBracket,
122+
(
123+
ParseState::LeftBracket | ParseState::Index | ParseState::Comma,
124+
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-',
125+
) => {
126+
if token.try_push(c).is_err() {
127+
return Err(ParseVectorError::TooLongNumber { position });
128+
}
129+
ParseState::Index
130+
}
131+
(ParseState::Colon, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => {
132+
if token.try_push(c).is_err() {
133+
return Err(ParseVectorError::TooLongNumber { position });
134+
}
135+
ParseState::Value
136+
}
137+
(ParseState::LeftBracket | ParseState::Comma, b'}') => ParseState::RightBracket,
138+
(ParseState::Index, b':') => {
139+
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
140+
let index = s
141+
.parse::<u32>()
142+
.map_err(|_| ParseVectorError::BadParsing { position })?;
143+
indexes.push(index);
144+
token.clear();
145+
ParseState::Colon
146+
}
147+
(ParseState::Value, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => {
148+
if token.try_push(c).is_err() {
149+
return Err(ParseVectorError::TooLongNumber { position });
150+
}
151+
ParseState::Value
152+
}
153+
(ParseState::Value, b',') => {
154+
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
155+
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
156+
values.push(num);
157+
token.clear();
158+
ParseState::Comma
159+
}
160+
(ParseState::Value, b'}') => {
161+
if token.is_empty() {
162+
return Err(ParseVectorError::TooShortNumber { position });
163+
}
164+
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
165+
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
166+
values.push(num);
167+
token.clear();
168+
ParseState::RightBracket
169+
}
170+
(ParseState::RightBracket, b'/') => ParseState::Splitter,
171+
(ParseState::Dims | ParseState::Splitter, b'0'..=b'9') => {
172+
if token.try_push(c).is_err() {
173+
return Err(ParseVectorError::TooLongNumber { position });
174+
}
175+
ParseState::Dims
176+
}
177+
(_, _) => {
178+
return Err(ParseVectorError::BadCharacter { position });
179+
}
180+
}
181+
}
182+
if state != ParseState::Dims {
183+
return Err(ParseVectorError::BadParsing {
184+
position: input.len(),
185+
});
186+
}
187+
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
188+
let dims = s
189+
.parse::<usize>()
190+
.map_err(|_| ParseVectorError::BadParsing {
191+
position: input.len(),
192+
})?;
193+
Ok((indexes, values, dims))
194+
}
195+
196+
#[cfg(test)]
mod tests {
    use super::*;
    use base::scalar::F32;

    /// Successful parser output: indexes, values and dimension count.
    type Parsed = (Vec<u32>, Vec<F32>, usize);

    /// Runs the parser on `expr`, converting each value with `str::parse::<F32>`.
    fn parse(expr: &str) -> Result<Parsed, ParseVectorError> {
        parse_pgvector_svector(expr.as_bytes(), |s| s.parse::<F32>().ok())
    }

    /// Well-formed literals parse to exactly the entries written, in input order.
    #[test]
    fn test_svector_parse_accept() {
        let cases: Vec<(&str, Parsed)> = vec![
            ("{}/1", (vec![], vec![], 1)),
            ("{0:1}/1", (vec![0], vec![F32(1.0)], 1)),
            // A trailing comma before `}` is tolerated.
            (
                "{0:1, 1:-2, }/2",
                (vec![0, 1], vec![F32(1.0), F32(-2.0)], 2),
            ),
            ("{0:1, 1:1.5}/2", (vec![0, 1], vec![F32(1.0), F32(1.5)], 2)),
            (
                "{0:+3, 2:-4.1}/3",
                (vec![0, 2], vec![F32(3.0), F32(-4.1)], 3),
            ),
            // Zero values are kept by the parser.
            (
                "{0:0, 1:0, 2:0}/3",
                (vec![0, 1, 2], vec![F32(0.0), F32(0.0), F32(0.0)], 3),
            ),
            // Unsorted input is returned unsorted.
            (
                "{3:3, 2:2, 1:1, 0:0}/4",
                (
                    vec![3, 2, 1, 0],
                    vec![F32(3.0), F32(2.0), F32(1.0), F32(0.0)],
                    4,
                ),
            ),
        ];
        for (expr, expected) in cases {
            let outcome = parse(expr);
            assert!(outcome.is_ok(), "at expr {:?}: {:?}", expr, outcome);
            assert_eq!(outcome.unwrap(), expected, "parsed at expr {:?}", expr);
        }
    }

    /// Malformed literals fail with the expected error kind and byte position.
    #[test]
    fn test_svector_parse_reject() {
        let cases: Vec<(&str, ParseVectorError)> = vec![
            ("{", ParseVectorError::BadParsing { position: 1 }),
            ("}", ParseVectorError::BadCharacter { position: 0 }),
            ("{:", ParseVectorError::BadCharacter { position: 1 }),
            (":}", ParseVectorError::BadCharacter { position: 0 }),
            (
                "{0:1, 1:2, 2:3}",
                ParseVectorError::BadParsing { position: 15 },
            ),
            (
                "{0:1, 1:2, 2:3",
                ParseVectorError::BadParsing { position: 14 },
            ),
            ("{0:1, 1:2}/", ParseVectorError::BadParsing { position: 11 }),
            ("{0}/5", ParseVectorError::BadCharacter { position: 2 }),
            ("{0:}/5", ParseVectorError::BadCharacter { position: 3 }),
            ("{:0}/5", ParseVectorError::BadCharacter { position: 1 }),
            (
                "{0:, 1:2}/5",
                ParseVectorError::BadCharacter { position: 3 },
            ),
            ("{0:1, 1}/5", ParseVectorError::BadCharacter { position: 7 }),
            ("/2", ParseVectorError::BadCharacter { position: 0 }),
            ("{}/1/2", ParseVectorError::BadCharacter { position: 4 }),
            (
                "{0:1, 1:2}/4/2",
                ParseVectorError::BadCharacter { position: 12 },
            ),
            ("{}/-4", ParseVectorError::BadCharacter { position: 3 }),
            (
                "{1,2,3,4}/5",
                ParseVectorError::BadCharacter { position: 2 },
            ),
        ];
        for (expr, expected) in cases {
            let outcome = parse(expr);
            assert!(outcome.is_err(), "at expr {:?}: {:?}", expr, outcome);
            assert_eq!(outcome.unwrap_err(), expected, "parsed at expr {:?}", expr);
        }
    }
}

tests/sqllogictest/sparse.slt

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ CREATE INDEX ON t USING vectors (val svector_cos_ops)
2020
WITH (options = "[indexing.hnsw]");
2121

2222
query I
23-
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
23+
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '{1:3,2:1}/6'::svector limit 10) t2;
2424
----
2525
10
2626

2727
query I
28-
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
28+
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '{1:3,2:1}/6'::svector limit 10) t2;
2929
----
3030
10
3131

3232
query I
33-
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
33+
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '{1:3,2:1}/6'::svector limit 10) t2;
3434
----
3535
10
3636

@@ -40,21 +40,18 @@ DROP TABLE t;
4040
query I
4141
SELECT to_svector(5, '{1,2}', '{1,2}');
4242
----
43-
[0, 1, 2, 0, 0]
43+
{1:1, 2:2}/5
4444

4545
query I
4646
SELECT to_svector(5, '{1,2}', '{1,1}') * to_svector(5, '{1,3}', '{2,2}');
4747
----
48-
[0, 2, 0, 0, 0]
48+
{1:2}/5
4949

5050
statement error Lengths of index and value are not matched.
5151
SELECT to_svector(5, '{1,2,3}', '{1,2}');
5252

5353
statement error Duplicated index.
5454
SELECT to_svector(5, '{1,1}', '{1,2}');
5555

56-
statement ok
57-
SELECT replace(replace(array_agg(RANDOM())::real[]::text, '{', '['), '}', ']')::svector FROM generate_series(1, 100000);
58-
5956
statement ok
6057
SELECT to_svector(200000, array_agg(val)::integer[], array_agg(val)::real[]) FROM generate_series(1, 100000) AS VAL;

0 commit comments

Comments
 (0)