From c6719dbf181c17fe42fc46d8d001a6e377cc508a Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Mon, 15 Apr 2024 10:11:01 +0800 Subject: [PATCH 01/10] feat: new text embedding for sparse vector Signed-off-by: cutecutecat --- src/datatype/text_svecf32.rs | 26 +++--- src/utils/parse.rs | 103 +++++++++++++++++++++++ tests/sqllogictest/sparse.slt | 11 +-- tests/sqllogictest/svector_subscript.slt | 50 +++++------ 4 files changed, 147 insertions(+), 43 deletions(-) diff --git a/src/datatype/text_svecf32.rs b/src/datatype/text_svecf32.rs index f2a0d6bb4..36d258221 100644 --- a/src/datatype/text_svecf32.rs +++ b/src/datatype/text_svecf32.rs @@ -10,13 +10,13 @@ use std::ffi::{CStr, CString}; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output { - use crate::utils::parse::parse_vector; + use crate::utils::parse::parse_pgvector_svector; let reserve = Typmod::parse_from_i32(typmod) .unwrap() .dims() .map(|x| x.get()) .unwrap_or(0); - let v = parse_vector(input.to_bytes(), reserve as usize, |s| { + let v = parse_pgvector_svector(input.to_bytes(), reserve as usize, |s| { s.parse::().ok() }); match v { @@ -40,16 +40,20 @@ fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output { #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString { + let dims = vector.for_borrow().dims(); let mut buffer = String::new(); - buffer.push('['); - let vec = vector.for_borrow().to_vec(); - let mut iter = vec.iter(); - if let Some(x) = iter.next() { - buffer.push_str(format!("{}", x).as_str()); - } - for x in iter { - buffer.push_str(format!(", {}", x).as_str()); + buffer.push('{'); + let svec = vector.for_borrow(); + let mut need_splitter = true; + for (&index, &value) in svec.indexes().iter().zip(svec.values().iter()) { + match need_splitter { + true => { + buffer.push_str(format!("{}:{}", index + 1, value).as_str()); + need_splitter = false; + } + false => buffer.push_str(format!(", {}:{}", index + 1, value).as_str()), + } } - buffer.push(']'); + buffer.push_str(format!("}}/{}", dims).as_str()); CString::new(buffer).unwrap() } diff --git a/src/utils/parse.rs b/src/utils/parse.rs index e5ef93316..a3dec2c78 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -1,3 +1,4 @@ +use num_traits::Zero; use thiserror::Error; #[derive(Debug, Error)] @@ -83,3 +84,105 @@ where } Ok(vector) } + +#[inline(always)] +pub fn parse_pgvector_svector( + input: &[u8], + reserve: usize, + f: F, +) -> Result, ParseVectorError> +where + F: Fn(&str) -> Option, +{ + use arrayvec::ArrayVec; + if input.is_empty() { + return Err(ParseVectorError::EmptyString {}); + } + let left = 'a: { + for position in 0..input.len() - 1 { + match input[position] { + b'{' => break 'a position, + b' ' => continue, + _ => return Err(ParseVectorError::BadCharacter { position }), + } + } + return Err(ParseVectorError::BadParentheses { character: '{' }); + }; + let mut token: ArrayVec = ArrayVec::new(); + let mut capacity = reserve; + let right = 'a: { + for position in (1..input.len()).rev() { + match input[position] { + b'0'..=b'9' => { + if token.try_push(input[position]).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); + } + } + b'/' => { + token.reverse(); + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; + capacity = s + .parse::() + .map_err(|_| ParseVectorError::BadParsing { position })?; + } + b'}' => { + token.clear(); + break 'a position; + } + b' ' => continue, + _ => return Err(ParseVectorError::BadCharacter { position }), + } + } + return Err(ParseVectorError::BadParentheses { character: '}' }); + }; + let mut vector = vec![T::zero(); capacity]; + let mut index: usize = 0; + for position in left + 1..right { + let c = input[position]; + match c { + b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { + if token.is_empty() { + token.push(b'$'); + } + if token.try_push(c).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); + } + } + b',' => { + if !token.is_empty() { + // Safety: all bytes in `token` are ascii characters + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; + vector[index] = num; + token.clear(); + } else { + return Err(ParseVectorError::TooShortNumber { position }); + } + } + b':' => { + if !token.is_empty() { + // Safety: all bytes in `token` are ascii characters + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + index = s + .parse::() + .map_err(|_| ParseVectorError::BadParsing { position })? + - 1; + token.clear(); + } else { + return Err(ParseVectorError::TooShortNumber { position }); + } + } + b' ' => (), + _ => return Err(ParseVectorError::BadCharacter { position }), + } + } + if !token.is_empty() { + let position = right; + // Safety: all bytes in `token` are ascii characters + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; + vector[index] = num; + token.clear(); + } + Ok(vector) +} diff --git a/tests/sqllogictest/sparse.slt b/tests/sqllogictest/sparse.slt index 07513d465..8aa5f6a97 100644 --- a/tests/sqllogictest/sparse.slt +++ b/tests/sqllogictest/sparse.slt @@ -20,17 +20,17 @@ CREATE INDEX ON t USING vectors (val svector_cos_ops) WITH (options = "[indexing.hnsw]"); query I -SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2; +SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '{1:3,2:1}/6'::svector limit 10) t2; ---- 10 query I -SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2; +SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '{1:3,2:1}/6'::svector limit 10) t2; ---- 10 query I -SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2; +SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '{1:3,2:1}/6'::svector limit 10) t2; ---- 10 @@ -40,7 +40,7 @@ DROP TABLE t; query I SELECT to_svector(5, '{1,2}', '{1,2}'); ---- -[0, 1, 2, 0, 0] +{2:1, 3:2}/5 query I SELECT to_svector(5, '{1,2}', '{1,1}') * to_svector(5, '{1,3}', '{2,2}'); @@ -53,8 +53,5 @@ SELECT to_svector(5, '{1,2,3}', '{1,2}'); statement error Duplicated index. SELECT to_svector(5, '{1,1}', '{1,2}'); -statement ok -SELECT replace(replace(array_agg(RANDOM())::real[]::text, '{', '['), '}', ']')::svector FROM generate_series(1, 100000); - statement ok SELECT to_svector(200000, array_agg(val)::integer[], array_agg(val)::real[]) FROM generate_series(1, 100000) AS VAL; diff --git a/tests/sqllogictest/svector_subscript.slt b/tests/sqllogictest/svector_subscript.slt index ad683b75a..e657199fa 100644 --- a/tests/sqllogictest/svector_subscript.slt +++ b/tests/sqllogictest/svector_subscript.slt @@ -2,87 +2,87 @@ statement ok SET search_path TO pg_temp, vectors; query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:6]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:6]; ---- -[3, 4, 5] +{1:3, 2:4, 3:5}/3 query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:4]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:4]; ---- -[0, 1, 2, 3] +{2:1, 3:2, 4:3}/4 query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:]; ---- -[5, 6, 7] +{1:5, 2:6, 3:7}/3 query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:8]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:8]; ---- -[1, 2, 3, 4, 5, 6, 7] +{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/7 statement error type svector does only support one subscript -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:3][1:1]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:3][1:1]; statement error type svector does only support slice fetch -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3]; query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:4]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:4]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[9:]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[9:]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:0]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:0]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:-1]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:-1]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:NULL]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:NULL]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:8]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:8]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:NULL]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:NULL]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:NULL]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:NULL]; ---- NULL query I -SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[3:7]; +SELECT ('{3:2, 5:4, 8:7}/8'::svector)[3:7]; ---- -[0, 4, 0, 0] +{2:4}/4 query I -SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[5:7]; +SELECT ('{3:2, 5:4, 8:7}/8'::svector)[5:7]; ---- -[0, 0] +{}/2 query I -SELECT ('[0, 0, 0, 0, 0, 0, 0, 0]'::svector)[5:7]; +SELECT ('{}/8'::svector)[5:7]; ---- -[0, 0] \ No newline at end of file +{}/2 \ No newline at end of file From 4b2d567c20905dba1e6cd16eb321702c1ebae58b Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 28 May 2024 20:34:12 +0800 Subject: [PATCH 02/10] fix: use 0-based index Signed-off-by: usamoi --- src/datatype/text_svecf32.rs | 4 ++-- src/utils/parse.rs | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/datatype/text_svecf32.rs b/src/datatype/text_svecf32.rs index 36d258221..cb5e00815 100644 --- a/src/datatype/text_svecf32.rs +++ b/src/datatype/text_svecf32.rs @@ -48,10 +48,10 @@ fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString { for (&index, &value) in svec.indexes().iter().zip(svec.values().iter()) { match need_splitter { true => { - buffer.push_str(format!("{}:{}", index + 1, value).as_str()); + buffer.push_str(format!("{}:{}", index, value).as_str()); need_splitter = false; } - false => buffer.push_str(format!(", {}:{}", index + 1, value).as_str()), + false => buffer.push_str(format!(", {}:{}", index, value).as_str()), } } buffer.push_str(format!("}}/{}", dims).as_str()); diff --git a/src/utils/parse.rs b/src/utils/parse.rs index a3dec2c78..e6eeda82b 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -165,8 +165,7 @@ where let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; index = s .parse::() - .map_err(|_| ParseVectorError::BadParsing { position })? - - 1; + .map_err(|_| ParseVectorError::BadParsing { position })?; token.clear(); } else { return Err(ParseVectorError::TooShortNumber { position }); From 5649ac8cf8a12c92b538dd9a696508db00e2ce3b Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Mon, 3 Jun 2024 15:42:06 +0800 Subject: [PATCH 03/10] refactor: use sparse struct to parse Signed-off-by: cutecutecat --- src/datatype/text_svecf32.rs | 27 ++--------- src/utils/parse.rs | 22 +++++---- tests/sqllogictest/sparse.slt | 4 +- tests/sqllogictest/svector.slt | 58 ++++++++++++------------ tests/sqllogictest/svector_subscript.slt | 44 +++++++++--------- 5 files changed, 70 insertions(+), 85 deletions(-) diff --git a/src/datatype/text_svecf32.rs b/src/datatype/text_svecf32.rs index cb5e00815..e32bddd83 100644 --- a/src/datatype/text_svecf32.rs +++ b/src/datatype/text_svecf32.rs @@ -1,39 +1,22 @@ use super::memory_svecf32::SVecf32Output; use crate::datatype::memory_svecf32::SVecf32Input; -use crate::datatype::typmod::Typmod; use crate::error::*; use base::scalar::*; use base::vector::*; -use num_traits::Zero; use pgrx::pg_sys::Oid; use std::ffi::{CStr, CString}; #[pgrx::pg_extern(immutable, strict, parallel_safe)] -fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output { +fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output { use crate::utils::parse::parse_pgvector_svector; - let reserve = Typmod::parse_from_i32(typmod) - .unwrap() - .dims() - .map(|x| x.get()) - .unwrap_or(0); - let v = parse_pgvector_svector(input.to_bytes(), reserve as usize, |s| { - s.parse::().ok() - }); + let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::().ok()); match v { Err(e) => { bad_literal(&e.to_string()); } - Ok(vector) => { - check_value_dims_1048575(vector.len()); - let mut indexes = Vec::::new(); - let mut values = Vec::::new(); - for (i, &x) in vector.iter().enumerate() { - if !x.is_zero() { - indexes.push(i as u32); - values.push(x); - } - } - SVecf32Output::new(SVecf32Borrowed::new(vector.len() as u32, &indexes, &values)) + Ok((indexes, values, dims)) => { + check_value_dims_1048575(dims); + SVecf32Output::new(SVecf32Borrowed::new(dims as u32, &indexes, &values)) } } } diff --git a/src/utils/parse.rs b/src/utils/parse.rs index e6eeda82b..8aa732ba1 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -88,9 +88,8 @@ where #[inline(always)] pub fn parse_pgvector_svector( input: &[u8], - reserve: usize, f: F, -) -> Result, ParseVectorError> +) -> Result<(Vec, Vec, usize), ParseVectorError> where F: Fn(&str) -> Option, { @@ -98,6 +97,7 @@ where if input.is_empty() { return Err(ParseVectorError::EmptyString {}); } + let mut dims: usize = 0; let left = 'a: { for position in 0..input.len() - 1 { match input[position] { @@ -109,7 +109,6 @@ where return Err(ParseVectorError::BadParentheses { character: '{' }); }; let mut token: ArrayVec = ArrayVec::new(); - let mut capacity = reserve; let right = 'a: { for position in (1..input.len()).rev() { match input[position] { @@ -121,7 +120,7 @@ where b'/' => { token.reverse(); let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; - capacity = s + dims = s .parse::() .map_err(|_| ParseVectorError::BadParsing { position })?; } @@ -135,8 +134,9 @@ where } return Err(ParseVectorError::BadParentheses { character: '}' }); }; - let mut vector = vec![T::zero(); capacity]; - let mut index: usize = 0; + let mut indexes = Vec::::new(); + let mut values = Vec::::new(); + let mut index: u32 = 0; for position in left + 1..right { let c = input[position]; match c { @@ -153,7 +153,8 @@ where // Safety: all bytes in `token` are ascii characters let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; - vector[index] = num; + indexes.push(index); + values.push(num); token.clear(); } else { return Err(ParseVectorError::TooShortNumber { position }); @@ -164,7 +165,7 @@ where // Safety: all bytes in `token` are ascii characters let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; index = s - .parse::() + .parse::() .map_err(|_| ParseVectorError::BadParsing { position })?; token.clear(); } else { @@ -180,8 +181,9 @@ where // Safety: all bytes in `token` are ascii characters let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; - vector[index] = num; + indexes.push(index); + values.push(num); token.clear(); } - Ok(vector) + Ok((indexes, values, dims)) } diff --git a/tests/sqllogictest/sparse.slt b/tests/sqllogictest/sparse.slt index 8aa5f6a97..fafededc8 100644 --- a/tests/sqllogictest/sparse.slt +++ b/tests/sqllogictest/sparse.slt @@ -40,12 +40,12 @@ DROP TABLE t; query I SELECT to_svector(5, '{1,2}', '{1,2}'); ---- -{2:1, 3:2}/5 +{1:1, 2:2}/5 query I SELECT to_svector(5, '{1,2}', '{1,1}') * to_svector(5, '{1,3}', '{2,2}'); ---- -[0, 2, 0, 0, 0] +{1:2}/5 statement error Lengths of index and value are not matched. SELECT to_svector(5, '{1,2,3}', '{1,2}'); diff --git a/tests/sqllogictest/svector.slt b/tests/sqllogictest/svector.slt index 9b7ab8661..14cfe4bfe 100644 --- a/tests/sqllogictest/svector.slt +++ b/tests/sqllogictest/svector.slt @@ -6,7 +6,7 @@ CREATE TABLE t (id bigserial, val svector); statement ok INSERT INTO t (val) -VALUES ('[1,2,3]'), ('[4,5,6]'); +VALUES ('{0:1, 1:2, 2:3}/3'), ('{0:4, 1:5, 2:6}/3'); query I SELECT vector_dims(val) FROM t; @@ -23,12 +23,12 @@ SELECT round(vector_norm(val)::numeric, 5) FROM t; query ? SELECT avg(val) FROM t; ---- -[2.5, 3.5, 4.5] +{0:2.5, 1:3.5, 2:4.5}/3 query ? SELECT sum(val) FROM t; ---- -[5, 7, 9] +{0:5, 1:7, 2:9}/3 statement ok CREATE TABLE test_vectors (id serial, data vector(1000)); @@ -46,35 +46,35 @@ SELECT count(*) FROM test_vectors; 5000 query R -SELECT vector_norm('[3,4]'::svector); +SELECT vector_norm('{0:3, 1:4}/2'::svector); ---- 5 query I -SELECT vector_dims(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v; +SELECT vector_dims(v) FROM unnest(ARRAY['{0:1, 1:2}/2'::svector, '{0:3}/1'::svector]) v; ---- 2 1 query ? -SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v; +SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:3, 1:5, 2:7}/3'::svector]) v; ---- -[2, 3.5, 5] +{0:2, 1:3.5, 2:5}/3 query ? -SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v; +SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:-1, 1:2, 2:-3}/3'::svector]) v; ---- -[0, 2, 0] +{1:2}/3 query ? -SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v; +SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:3, 1:5, 2:7}/3'::svector, NULL]) v; ---- -[2, 3.5, 5] +{0:2, 1:3.5, 2:5}/3 query ? -SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector,NULL]) v; +SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector,NULL]) v; ---- -[1, 2, 3] +{0:1, 1:2, 2:3}/3 query ? SELECT avg(v) FROM unnest(ARRAY[]::svector[]) v; @@ -87,22 +87,22 @@ SELECT avg(v) FROM unnest(ARRAY[NULL]::svector[]) v; NULL query ? -SELECT avg(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v; +SELECT avg(v) FROM unnest(ARRAY['{0:3e38}/1'::svector, '{0:3e38}/1'::svector]) v; ---- -[inf] +{0:inf}/1 statement error differs in dimensions -SELECT avg(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v; +SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2}/2'::svector, '{0:3}/1'::svector]) v; query ? SELECT avg(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{2,3}'), to_svector(5, '{0,2}', '{1,3}'), to_svector(5, '{3,4}', '{3,3}')]) v; ---- -[1, 1, 1, 1, 1] +{0:1, 1:1, 2:1, 3:1, 4:1}/5 query ? SELECT avg(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v; ---- -[0.33333334, 0.6666667, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.6666667, 0.33333334, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +{0:0.33333334, 1:0.6666667, 2:1, 3:1, 4:1, 5:1, 6:1, 7:1, 8:1, 9:1, 10:1, 11:1, 12:1, 13:1, 14:1, 15:1, 16:0.6666667, 17:0.33333334}/32 # test avg(svector) get the same result as avg(vector) query ? @@ -111,20 +111,20 @@ SELECT avg(data) = avg(data::svector)::vector FROM test_vectors; t query ? -SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v; +SELECT sum(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:3, 1:5, 2:7}/3'::svector]) v; ---- -[4, 7, 10] +{0:4, 1:7, 2:10}/3 # test zero element query ? -SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v; +SELECT sum(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:-1, 1:2, 2:-3}/3'::svector]) v; ---- -[0, 4, 0] +{1:4}/3 query ? -SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v; +SELECT sum(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:3, 1:5, 2:7}/3'::svector, NULL]) v; ---- -[4, 7, 10] +{0:4, 1:7, 2:10}/3 query ? SELECT sum(v) FROM unnest(ARRAY[]::svector[]) v; @@ -137,23 +137,23 @@ SELECT sum(v) FROM unnest(ARRAY[NULL]::svector[]) v; NULL statement error differs in dimensions -SELECT sum(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v; +SELECT sum(v) FROM unnest(ARRAY['{0:1, 1:2}/2'::svector, '{0:3}/1'::svector]) v; # should this return an error ? query ? -SELECT sum(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v; +SELECT sum(v) FROM unnest(ARRAY['{0:3e38}/1'::svector, '{0:3e38}/1'::svector]) v; ---- -[inf] +{0:inf}/1 query ? SELECT sum(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{1,2}'), to_svector(5, '{0,2}', '{1,2}'), to_svector(5, '{3,4}', '{3,3}')]) v; ---- -[2, 2, 2, 3, 3] +{0:2, 1:2, 2:2, 3:3, 4:3}/5 query ? SELECT sum(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v; ---- -[1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +{0:1, 1:2, 2:3, 3:3, 4:3, 5:3, 6:3, 7:3, 8:3, 9:3, 10:3, 11:3, 12:3, 13:3, 14:3, 15:3, 16:2, 17:1}/32 # test sum(svector) get the same result as sum(vector) query ? diff --git a/tests/sqllogictest/svector_subscript.slt b/tests/sqllogictest/svector_subscript.slt index e657199fa..d01fe0af2 100644 --- a/tests/sqllogictest/svector_subscript.slt +++ b/tests/sqllogictest/svector_subscript.slt @@ -2,83 +2,83 @@ statement ok SET search_path TO pg_temp, vectors; query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:6]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[3:6]; ---- -{1:3, 2:4, 3:5}/3 +{0:3, 1:4, 2:5}/3 query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:4]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:4]; ---- -{2:1, 3:2, 4:3}/4 +{1:1, 2:2, 3:3}/4 query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[5:]; ---- -{1:5, 2:6, 3:7}/3 +{0:5, 1:6, 2:7}/3 query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:8]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[1:8]; ---- -{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/7 +{0:1, 1:2, 2:3, 3:4, 4:5, 5:6, 6:7}/7 statement error type svector does only support one subscript -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:3][1:1]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[3:3][1:1]; statement error type svector does only support slice fetch -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[3]; query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:4]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[5:4]; ---- NULL query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[9:]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[9:]; ---- NULL query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:0]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:0]; ---- NULL query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:-1]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:-1]; ---- NULL query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:NULL]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[NULL:NULL]; ---- NULL query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:8]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[NULL:8]; ---- NULL query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:NULL]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[1:NULL]; ---- NULL query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[NULL:]; ---- NULL query I -SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:NULL]; +SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:NULL]; ---- NULL query I -SELECT ('{3:2, 5:4, 8:7}/8'::svector)[3:7]; +SELECT ('{2:2, 4:4, 7:7}/8'::svector)[3:7]; ---- -{2:4}/4 +{1:4}/4 query I -SELECT ('{3:2, 5:4, 8:7}/8'::svector)[5:7]; +SELECT ('{2:2, 4:4, 7:7}/8'::svector)[5:7]; ---- {}/2 From 4d5bd936d99b8c61b7666f5505ee66fb4acffce2 Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Tue, 4 Jun 2024 16:53:46 +0800 Subject: [PATCH 04/10] fix: zero-check, sort and tests Signed-off-by: cutecutecat --- src/utils/parse.rs | 115 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 6 deletions(-) diff --git a/src/utils/parse.rs b/src/utils/parse.rs index 8aa732ba1..02127960b 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -15,6 +15,8 @@ pub enum ParseVectorError { TooShortNumber { position: usize }, #[error("Bad parsing at position {position}")] BadParsing { position: usize }, + #[error("Index out of bounds: the dim is {dims} but the index is {index}")] + OutOfBound { dims: usize, index: usize }, } #[inline(always)] @@ -85,6 +87,14 @@ where Ok(vector) } +#[derive(PartialEq)] +enum ParseState { + Number, + Comma, + Colon, + Start, +} + #[inline(always)] pub fn parse_pgvector_svector( input: &[u8], @@ -136,7 +146,9 @@ where }; let mut indexes = Vec::::new(); let mut values = Vec::::new(); - let mut index: u32 = 0; + let mut index: u32 = u32::MAX; + let mut state = ParseState::Start; + for position in left + 1..right { let c = input[position]; match c { @@ -147,15 +159,29 @@ where if token.try_push(c).is_err() { return Err(ParseVectorError::TooLongNumber { position }); } + state = ParseState::Number; } b',' => { + if state != ParseState::Number { + return Err(ParseVectorError::BadCharacter { position }); + } if !token.is_empty() { // Safety: all bytes in `token` are ascii characters let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; - indexes.push(index); - values.push(num); + if index as usize >= dims { + return Err(ParseVectorError::OutOfBound { + dims, + index: index as usize, + }); + } + if !num.is_zero() { + indexes.push(index); + values.push(num); + } + index = u32::MAX; token.clear(); + state = ParseState::Comma; } else { return Err(ParseVectorError::TooShortNumber { position }); } @@ -168,6 +194,7 @@ where .parse::() .map_err(|_| ParseVectorError::BadParsing { position })?; token.clear(); + state = ParseState::Colon; } else { return Err(ParseVectorError::TooShortNumber { position }); } @@ -176,14 +203,90 @@ where _ => return Err(ParseVectorError::BadCharacter { position }), } } + if state != ParseState::Start && (state != ParseState::Number || index == u32::MAX) { + return Err(ParseVectorError::BadCharacter { position: right }); + } if !token.is_empty() { let position = right; // Safety: all bytes in `token` are ascii characters let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; - indexes.push(index); - values.push(num); + if index as usize >= dims { + return Err(ParseVectorError::OutOfBound { + dims, + index: index as usize, + }); + } + if !num.is_zero() { + indexes.push(index); + values.push(num); + } token.clear(); } - Ok((indexes, values, dims)) + // sort values and indexes ascend by indexes + let mut indices = (0..indexes.len()).collect::>(); + indices.sort_by_key(|&i| &indexes[i]); + let sortedValues: Vec = indices + .iter() + .map(|i| values.get(*i).unwrap().clone()) + .collect(); + indexes.sort(); + Ok((indexes, sortedValues, dims)) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use base::scalar::F32; + + use super::*; + + #[test] + fn test_svector_parse_accept() { + let exprs: HashMap<&str, (Vec, Vec, usize)> = HashMap::from([ + ("{}/1", (vec![], vec![], 1)), + ("{0:1}/1", (vec![0], vec![F32(1.0)], 1)), + ("{0:1, 1:1.5}/2", (vec![0, 1], vec![F32(1.0), F32(1.5)], 2)), + ( + "{0:+3, 2:-4.1}/3", + (vec![0, 2], vec![F32(3.0), F32(-4.1)], 3), + ), + ("{0:0, 1:0, 2:0}/3", (vec![], vec![], 3)), + ( + "{3:3, 2:2, 1:1, 0:0}/4", + (vec![1, 2, 3], vec![F32(1.0), F32(2.0), F32(3.0)], 4), + ), + ]); + for (e, ans) in exprs { + let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); + assert!(ret.is_ok(), "at expr {e}"); + assert_eq!(ret.unwrap(), ans, "at expr {e}"); + } + } + + #[test] + fn test_svector_parse_reject() { + let exprs: Vec<&str> = vec![ + "{", + "}", + "{:", + ":}", + "{0:1, 1:1.5}/1", + "{0:0, 1:0, 2:0}/2", + "{0:1, 1:2, 2:3}", + "{0:1, 1:2, 2:3", + "{0:1, 1:2}/", + "{0}/5", + "{0:}/5", + "{:0}/5", + "{0:, 1:2}/5", + "{0:1, 1}/5", + "/2", + ]; + for e in exprs { + let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); + assert!(ret.is_err(), "at expr {e}") + } + } } From 6da2c90dd478130f0ff179f59755e77dd53325db Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Wed, 5 Jun 2024 15:35:24 +0800 Subject: [PATCH 05/10] fix: new reject case Signed-off-by: cutecutecat --- src/utils/parse.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/utils/parse.rs b/src/utils/parse.rs index 02127960b..3c4d71269 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -107,7 +107,7 @@ where if input.is_empty() { return Err(ParseVectorError::EmptyString {}); } - let mut dims: usize = 0; + let mut dims: usize = usize::MAX; let left = 'a: { for position in 0..input.len() - 1 { match input[position] { @@ -130,6 +130,10 @@ where b'/' => { token.reverse(); let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; + // two `dims` are found + if dims != usize::MAX { + return Err(ParseVectorError::BadCharacter { position }); + } dims = s .parse::() .map_err(|_| ParseVectorError::BadParsing { position })?; @@ -144,6 +148,12 @@ where } return Err(ParseVectorError::BadParentheses { character: '}' }); }; + // `dims` is not found + if dims == usize::MAX { + return Err(ParseVectorError::BadCharacter { + position: input.len(), + }); + } let mut indexes = Vec::::new(); let mut values = Vec::::new(); let mut index: u32 = u32::MAX; @@ -203,6 +213,9 @@ where _ => return Err(ParseVectorError::BadCharacter { position }), } } + // A valid case is either + // - empty string: "" + // - end with number when a index is extracted:"1:2, 3:4" if state != ParseState::Start && (state != ParseState::Number || index == u32::MAX) { return Err(ParseVectorError::BadCharacter { position: right }); } @@ -283,6 +296,7 @@ mod tests { "{0:, 1:2}/5", "{0:1, 1}/5", "/2", + "{}/1/2", ]; for e in exprs { let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); From 45ab3c4bdb5ba3e0e56878979e8f62a27428244d Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Wed, 5 Jun 2024 17:58:36 +0800 Subject: [PATCH 06/10] refactor: use state machine Signed-off-by: cutecutecat --- src/utils/parse.rs | 184 ++++++++++++++++++++++++++++----------------- 1 file changed, 115 insertions(+), 69 deletions(-) diff --git a/src/utils/parse.rs b/src/utils/parse.rs index 3c4d71269..984c235f7 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -1,7 +1,7 @@ use num_traits::Zero; use thiserror::Error; -#[derive(Debug, Error)] +#[derive(Debug, Error, PartialEq)] pub enum ParseVectorError { #[error("The input string is empty.")] EmptyString {}, @@ -89,12 +89,15 @@ where #[derive(PartialEq)] enum ParseState { - Number, + Start, + Index, + Value, Comma, Colon, - Start, + End, } +// Index -> Colon -> Value -> Comma #[inline(always)] pub fn parse_pgvector_svector( input: &[u8], @@ -157,24 +160,59 @@ where let mut indexes = Vec::::new(); let mut values = Vec::::new(); let mut index: u32 = u32::MAX; - let mut state = ParseState::Start; - for position in left + 1..right { - let c = input[position]; - match c { - b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { - if token.is_empty() { - token.push(b'$'); - } - if token.try_push(c).is_err() { - return Err(ParseVectorError::TooLongNumber { position }); + let mut state = ParseState::Start; + let mut position = left; + loop { + if position == right { + let end_with_number = state == ParseState::Value && !token.is_empty(); + let end_with_comma = state == ParseState::Index && token.is_empty(); + if end_with_number || end_with_comma { + state = ParseState::End; + } else { + return Err(ParseVectorError::BadCharacter { position }); + } + } + match state { + ParseState::Index => { + let c = input[position]; + match c { + b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { + if token.is_empty() { + token.push(b'$'); + } + if token.try_push(c).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); + } + position += 1; + } + b':' => { + state = ParseState::Colon; + } + b' ' => position += 1, + _ => return Err(ParseVectorError::BadCharacter { position }), } - state = ParseState::Number; } - b',' => { - if state != ParseState::Number { - return Err(ParseVectorError::BadCharacter { position }); + ParseState::Value => { + let c = input[position]; + match c { + b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { + if token.is_empty() { + token.push(b'$'); + } + if token.try_push(c).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); + } + position += 1; + } + b',' => { + state = ParseState::Comma; + } + b' ' => position += 1, + _ => return Err(ParseVectorError::BadCharacter { position }), } + } + e @ (ParseState::Comma | ParseState::End) => { if !token.is_empty() { // Safety: all bytes in `token` are ascii characters let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; @@ -191,12 +229,17 @@ where } index = u32::MAX; token.clear(); - state = ParseState::Comma; - } else { + } else if e != ParseState::End { return Err(ParseVectorError::TooShortNumber { position }); } + if e == ParseState::End { + break; + } else { + state = ParseState::Index; + position += 1; + } } - b':' => { + ParseState::Colon => { if !token.is_empty() { // Safety: all bytes in `token` are ascii characters let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; @@ -204,47 +247,26 @@ where .parse::() .map_err(|_| ParseVectorError::BadParsing { position })?; token.clear(); - state = ParseState::Colon; } else { return Err(ParseVectorError::TooShortNumber { position }); } + state = ParseState::Value; + position += 1; + } + ParseState::Start => { + state = ParseState::Index; + position += 1; } - b' ' => (), - _ => return Err(ParseVectorError::BadCharacter { position }), - } - } - // A valid case is either - // - empty string: "" - // - end with number when a index is extracted:"1:2, 3:4" - if state != ParseState::Start && (state != ParseState::Number || index == u32::MAX) { - return Err(ParseVectorError::BadCharacter { position: right }); - } - if !token.is_empty() { - let position = right; - // Safety: all bytes in `token` are ascii characters - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; - let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; - if index as usize >= dims { - return Err(ParseVectorError::OutOfBound { - dims, - index: index as usize, - }); - } - if !num.is_zero() { - indexes.push(index); - values.push(num); } - token.clear(); } - // sort values and indexes ascend by indexes let mut indices = (0..indexes.len()).collect::>(); indices.sort_by_key(|&i| &indexes[i]); - let sortedValues: Vec = indices + let sorted_values: Vec = indices .iter() .map(|i| values.get(*i).unwrap().clone()) .collect(); indexes.sort(); - Ok((indexes, sortedValues, dims)) + Ok((indexes, sorted_values, dims)) } #[cfg(test)] @@ -260,6 +282,10 @@ mod tests { let exprs: HashMap<&str, (Vec, Vec, usize)> = HashMap::from([ ("{}/1", (vec![], vec![], 1)), ("{0:1}/1", (vec![0], vec![F32(1.0)], 1)), + ( + "{0:1, 1:-2, }/2", + (vec![0, 1], vec![F32(1.0), F32(-2.0)], 2), + ), ("{0:1, 1:1.5}/2", (vec![0, 1], vec![F32(1.0), F32(1.5)], 2)), ( "{0:+3, 2:-4.1}/3", @@ -280,27 +306,47 @@ mod tests { #[test] fn test_svector_parse_reject() { - let exprs: Vec<&str> = vec![ - "{", - "}", - "{:", - ":}", - "{0:1, 1:1.5}/1", - "{0:0, 1:0, 2:0}/2", - "{0:1, 1:2, 2:3}", - "{0:1, 1:2, 2:3", - "{0:1, 1:2}/", - "{0}/5", - "{0:}/5", - "{:0}/5", - "{0:, 1:2}/5", - "{0:1, 1}/5", - "/2", - "{}/1/2", - ]; - for e in exprs { + let exprs: HashMap<&str, ParseVectorError> = HashMap::from([ + ("{", ParseVectorError::BadParentheses { character: '{' }), + ("}", ParseVectorError::BadParentheses { character: '{' }), + ("{:", ParseVectorError::BadCharacter { position: 1 }), + (":}", ParseVectorError::BadCharacter { position: 0 }), + ( + "{0:1, 1:1.5}/1", + ParseVectorError::OutOfBound { dims: 1, index: 1 }, + ), + ( + "{0:0, 1:0, 2:0}/2", + ParseVectorError::OutOfBound { dims: 2, index: 2 }, + ), + ( + "{0:1, 1:2, 2:3}", + ParseVectorError::BadCharacter { position: 15 }, + ), + ( + "{0:1, 1:2, 2:3", + ParseVectorError::BadCharacter { position: 12 }, + ), + ("{0:1, 1:2}/", ParseVectorError::BadParsing { position: 10 }), + ("{0}/5", ParseVectorError::BadCharacter { position: 2 }), + ("{0:}/5", ParseVectorError::BadCharacter { position: 3 }), + ("{:0}/5", ParseVectorError::TooShortNumber { position: 1 }), + ( + "{0:, 1:2}/5", + ParseVectorError::TooShortNumber { position: 3 }, + ), + ("{0:1, 1}/5", ParseVectorError::BadCharacter { position: 7 }), + ("/2", ParseVectorError::BadCharacter { position: 0 }), + ("{}/1/2", ParseVectorError::BadCharacter { position: 2 }), + ( + "{1,2,3,4}/5", + ParseVectorError::BadCharacter { position: 2 }, + ), + ]); + for (e, err) in exprs { let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); - assert!(ret.is_err(), "at expr {e}") + assert!(ret.is_err(), "at expr {e}"); + assert_eq!(ret.unwrap_err(), err, "at expr {e}"); } } } From 0fbe5fc6c15343c260dfea8d99f394dd1334959f Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Mon, 17 Jun 2024 16:08:44 +0800 Subject: [PATCH 07/10] fix: fsm with more checks Signed-off-by: cutecutecat --- src/utils/parse.rs | 308 ++++++++++++++++++++++++++++----------------- 1 file changed, 196 insertions(+), 112 deletions(-) diff --git a/src/utils/parse.rs b/src/utils/parse.rs index 984c235f7..79b1fa47b 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use num_traits::Zero; use thiserror::Error; @@ -17,6 +19,10 @@ pub enum ParseVectorError { BadParsing { position: usize }, #[error("Index out of bounds: the dim is {dims} but the index is {index}")] OutOfBound { dims: usize, index: usize }, + #[error("The dimension should be {min} < dim < {max}, but it is actually {dims}")] + InvalidDimension { dims: usize, min: usize, max: usize }, + #[error("Indexes need to be unique, but there are more than one same index {index}")] + IndexConflict { index: usize }, } #[inline(always)] @@ -90,14 +96,14 @@ where #[derive(PartialEq)] enum ParseState { Start, + LeftBracket, Index, Value, + Splitter, Comma, - Colon, - End, + Length, } -// Index -> Colon -> Value -> Comma #[inline(always)] pub fn parse_pgvector_svector( input: &[u8], @@ -110,70 +116,45 @@ where if input.is_empty() { return Err(ParseVectorError::EmptyString {}); } - let mut dims: usize = usize::MAX; - let left = 'a: { - for position in 0..input.len() - 1 { - match input[position] { - b'{' => break 'a position, - b' ' => continue, - _ => return Err(ParseVectorError::BadCharacter { position }), - } - } - return Err(ParseVectorError::BadParentheses { character: '{' }); - }; let mut token: ArrayVec = ArrayVec::new(); - let right = 'a: { - for position in (1..input.len()).rev() { - match input[position] { - b'0'..=b'9' => { - if token.try_push(input[position]).is_err() { - return Err(ParseVectorError::TooLongNumber { position }); - } - } - b'/' => { - token.reverse(); - let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; - // two `dims` are found - if dims != usize::MAX { - return Err(ParseVectorError::BadCharacter { position }); - } - dims = s - .parse::() - .map_err(|_| ParseVectorError::BadParsing { position })?; - } - b'}' => { - token.clear(); - break 'a position; - } - b' ' => continue, - _ => return Err(ParseVectorError::BadCharacter { position }), - } - } - return Err(ParseVectorError::BadParentheses { character: '}' }); - }; - // `dims` is not found - if dims == usize::MAX { - return Err(ParseVectorError::BadCharacter { - position: input.len(), - }); - } + let mut indexes = Vec::::new(); let mut values = Vec::::new(); + let mut all_indexes = Vec::::new(); let mut index: u32 = u32::MAX; let mut state = ParseState::Start; - let mut position = left; + let mut position = 0; loop { - if position == right { - let end_with_number = state == ParseState::Value && !token.is_empty(); - let end_with_comma = state == ParseState::Index && token.is_empty(); - if end_with_number || end_with_comma { - state = ParseState::End; - } else { - return Err(ParseVectorError::BadCharacter { position }); - } + if position >= input.len() { + break; } match state { + ParseState::Start => { + let c = input[position]; + match c { + b'{' => { + state = ParseState::LeftBracket; + } + b' ' => {} + _ => return Err(ParseVectorError::BadCharacter { position }), + } + } + ParseState::LeftBracket => { + let c = input[position]; + match c { + b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { + // Do not read it here, goto Index to read + position -= 1; + state = ParseState::Index; + } + b'}' => { + state = ParseState::Splitter; + } + b' ' => {} + _ => return Err(ParseVectorError::BadCharacter { position }), + } + } ParseState::Index => { let c = input[position]; match c { @@ -184,12 +165,19 @@ where if token.try_push(c).is_err() { return Err(ParseVectorError::TooLongNumber { position }); } - position += 1; } b':' => { - state = ParseState::Colon; + if token.is_empty() { + return Err(ParseVectorError::TooShortNumber { position }); + } + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + index = s + .parse::() + .map_err(|_| ParseVectorError::BadParsing { position })?; + token.clear(); + state = ParseState::Value; } - b' ' => position += 1, + b' ' => {} _ => return Err(ParseVectorError::BadCharacter { position }), } } @@ -203,62 +191,125 @@ where if token.try_push(c).is_err() { return Err(ParseVectorError::TooLongNumber { position }); } - position += 1; } b',' => { + if token.is_empty() { + return Err(ParseVectorError::TooShortNumber { position }); + } + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; + if !num.is_zero() { + indexes.push(index); + values.push(num); + } + all_indexes.push(index); + token.clear(); state = ParseState::Comma; } - b' ' => position += 1, + // Bracket ended with number + b'}' => { + if token.is_empty() { + return Err(ParseVectorError::TooShortNumber { position }); + } + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; + if !num.is_zero() { + indexes.push(index); + values.push(num); + } + all_indexes.push(index); + token.clear(); + state = ParseState::Splitter; + } + b' ' => {} _ => return Err(ParseVectorError::BadCharacter { position }), } } - e @ (ParseState::Comma | ParseState::End) => { - if !token.is_empty() { - // Safety: all bytes in `token` are ascii characters - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; - let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; - if index as usize >= dims { - return Err(ParseVectorError::OutOfBound { - dims, - index: index as usize, - }); + ParseState::Comma => { + let c = input[position]; + match c { + b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { + // Do not read it here, goto Index to read + position -= 1; + state = ParseState::Index; } - if !num.is_zero() { - indexes.push(index); - values.push(num); + b'}' => { + // Bracket ended with comma + state = ParseState::Splitter; } - index = u32::MAX; - token.clear(); - } else if e != ParseState::End { - return Err(ParseVectorError::TooShortNumber { position }); + b' ' => {} + _ => return Err(ParseVectorError::BadCharacter { position }), } - if e == ParseState::End { - break; - } else { - state = ParseState::Index; - position += 1; + } + ParseState::Splitter => { + let c = input[position]; + match c { + b'/' => { + state = ParseState::Length; + } + _ => return Err(ParseVectorError::BadCharacter { position }), } } - ParseState::Colon => { - if !token.is_empty() { - // Safety: all bytes in `token` are ascii characters - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; - index = s - .parse::() - .map_err(|_| ParseVectorError::BadParsing { position })?; - token.clear(); - } else { - return Err(ParseVectorError::TooShortNumber { position }); + ParseState::Length => { + let c = input[position]; + match c { + b'0'..=b'9' => { + if token.is_empty() { + token.push(b'$'); + } + if token.try_push(c).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); + } + } + _ => return Err(ParseVectorError::BadCharacter { position }), } - state = ParseState::Value; - position += 1; } - ParseState::Start => { - state = ParseState::Index; - position += 1; + } + position += 1; + } + if state != ParseState::Length { + return Err(ParseVectorError::BadParsing { + position: input.len(), + }); + } + if token.is_empty() { + return Err(ParseVectorError::TooShortNumber { position }); + } + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let dims = s + .parse::() + .map_err(|_| ParseVectorError::BadParsing { position })?; + + // Check dimension out of bound + if dims == 0 || dims >= 1048576 { + return Err(ParseVectorError::InvalidDimension { + dims, + min: 0, + max: 1048576, + }); + } + // Check index out of bound + for index in all_indexes.clone() { + if index as usize >= dims { + return Err(ParseVectorError::OutOfBound { + dims, + index: index as usize, + }); + } + } + // Check index conflicts + let mut result: HashMap = HashMap::new(); + for index in all_indexes { + if let Some(value) = result.get(&index) { + if *value == 1 { + return Err(ParseVectorError::IndexConflict { + index: index as usize, + }); } } + *result.entry(index).or_insert(0) += 1; } + let mut indices = (0..indexes.len()).collect::>(); indices.sort_by_key(|&i| &indexes[i]); let sorted_values: Vec = indices @@ -299,18 +350,35 @@ mod tests { ]); for (e, ans) in exprs { let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); - assert!(ret.is_ok(), "at expr {e}"); - assert_eq!(ret.unwrap(), ans, "at expr {e}"); + assert!(ret.is_ok(), "at expr {:?}: {:?}", e, ret); + assert_eq!(ret.unwrap(), ans, "at expr {:?}", e); } } #[test] fn test_svector_parse_reject() { let exprs: HashMap<&str, ParseVectorError> = HashMap::from([ - ("{", ParseVectorError::BadParentheses { character: '{' }), - ("}", ParseVectorError::BadParentheses { character: '{' }), + ("{", ParseVectorError::BadParsing { position: 1 }), + ("}", ParseVectorError::BadCharacter { position: 0 }), ("{:", ParseVectorError::BadCharacter { position: 1 }), (":}", ParseVectorError::BadCharacter { position: 0 }), + ( + "{}/0", + ParseVectorError::InvalidDimension { + dims: 0, + min: 0, + max: 1048576, + }, + ), + ( + "{}/1919810", + ParseVectorError::InvalidDimension { + dims: 1919810, + min: 0, + max: 1048576, + }, + ), + ("{0:1, 0:2}/1", ParseVectorError::IndexConflict { index: 0 }), ( "{0:1, 1:1.5}/1", ParseVectorError::OutOfBound { dims: 1, index: 1 }, @@ -319,25 +387,41 @@ mod tests { "{0:0, 1:0, 2:0}/2", ParseVectorError::OutOfBound { dims: 2, index: 2 }, ), + ( + "{2:0, 1:0}/2", + ParseVectorError::OutOfBound { dims: 2, index: 2 }, + ), + ( + "{2:0, 1:0, }/2", + ParseVectorError::OutOfBound { dims: 2, index: 2 }, + ), ( "{0:1, 1:2, 2:3}", - ParseVectorError::BadCharacter { position: 15 }, + ParseVectorError::BadParsing { position: 15 }, ), ( "{0:1, 1:2, 2:3", - ParseVectorError::BadCharacter { position: 12 }, + ParseVectorError::BadParsing { position: 14 }, + ), + ( + "{0:1, 1:2}/", + ParseVectorError::TooShortNumber { position: 11 }, ), - ("{0:1, 1:2}/", ParseVectorError::BadParsing { position: 10 }), ("{0}/5", ParseVectorError::BadCharacter { position: 2 }), - ("{0:}/5", ParseVectorError::BadCharacter { position: 3 }), - ("{:0}/5", ParseVectorError::TooShortNumber { position: 1 }), + ("{0:}/5", ParseVectorError::TooShortNumber { position: 3 }), + ("{:0}/5", ParseVectorError::BadCharacter { position: 1 }), ( "{0:, 1:2}/5", ParseVectorError::TooShortNumber { position: 3 }, ), ("{0:1, 1}/5", ParseVectorError::BadCharacter { position: 7 }), ("/2", ParseVectorError::BadCharacter { position: 0 }), - ("{}/1/2", ParseVectorError::BadCharacter { position: 2 }), + ("{}/1/2", ParseVectorError::BadCharacter { position: 4 }), + ( + "{0:1, 1:2}/4/2", + ParseVectorError::BadCharacter { position: 12 }, + ), + ("{}/-4", ParseVectorError::BadCharacter { position: 3 }), ( "{1,2,3,4}/5", ParseVectorError::BadCharacter { position: 2 }, @@ -345,8 +429,8 @@ mod tests { ]); for (e, err) in exprs { let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); - assert!(ret.is_err(), "at expr {e}"); - assert_eq!(ret.unwrap_err(), err, "at expr {e}"); + assert!(ret.is_err(), "at expr {:?}: {:?}", e, ret); + assert_eq!(ret.unwrap_err(), err, "at expr {:?}", e); } } } From 8a8561e8765c2c677caa2cde0cdb8bf080c11b0b Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Tue, 18 Jun 2024 10:22:11 +0800 Subject: [PATCH 08/10] fix: by comments Signed-off-by: cutecutecat --- src/datatype/text_svecf32.rs | 23 ++- src/error.rs | 14 ++ src/utils/parse.rs | 365 ++++++++++++++--------------------- 3 files changed, 178 insertions(+), 224 deletions(-) diff --git a/src/datatype/text_svecf32.rs b/src/datatype/text_svecf32.rs index e32bddd83..ed64c3716 100644 --- a/src/datatype/text_svecf32.rs +++ b/src/datatype/text_svecf32.rs @@ -5,10 +5,11 @@ use base::scalar::*; use base::vector::*; use pgrx::pg_sys::Oid; use std::ffi::{CStr, CString}; +use std::fmt::Write; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output { - use crate::utils::parse::parse_pgvector_svector; + use crate::utils::parse::{parse_pgvector_svector, svector_filter_nonzero}; let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::().ok()); match v { Err(e) => { @@ -16,7 +17,13 @@ fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output { } Ok((indexes, values, dims)) => { check_value_dims_1048575(dims); - SVecf32Output::new(SVecf32Borrowed::new(dims as u32, &indexes, &values)) + check_index_in_bound(&indexes, dims); + let (non_zero_indexes, non_zero_values) = svector_filter_nonzero(&indexes, &values); + SVecf32Output::new(SVecf32Borrowed::new( + dims as u32, + &non_zero_indexes, + &non_zero_values, + )) } } } @@ -27,16 +34,16 @@ fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString { let mut buffer = String::new(); buffer.push('{'); let svec = vector.for_borrow(); - let mut need_splitter = true; + let mut need_splitter = false; for (&index, &value) in svec.indexes().iter().zip(svec.values().iter()) { match need_splitter { - true => { - buffer.push_str(format!("{}:{}", index, value).as_str()); - need_splitter = false; + false => { + write!(buffer, "{}:{}", index, value).unwrap(); + need_splitter = true; } - false => buffer.push_str(format!(", {}:{}", index, value).as_str()), + true => write!(buffer, ", {}:{}", index, value).unwrap(), } } - buffer.push_str(format!("}}/{}", dims).as_str()); + write!(buffer, "}}/{}", dims).unwrap(); CString::new(buffer).unwrap() } diff --git a/src/error.rs b/src/error.rs index fd368de4c..99a8a34c0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -69,6 +69,20 @@ ADVICE: Check if dimensions of the vector are among 1 and 1_048_575." NonZeroU32::new(dims as u32).unwrap() } +pub fn check_index_in_bound(indexes: &[u32], dims: usize) -> NonZeroU32 { + let mut last: u32 = 0; + for (i, index) in indexes.iter().enumerate() { + if i > 0 && last == *index { + error!("Indexes need to be unique, but there are more than one same index {index}") + } + if *index >= dims as u32 { + error!("Index out of bounds: the dim is {dims} but the index is {index}"); + } + last = *index; + } + NonZeroU32::new(dims as u32).unwrap() +} + pub fn bad_literal(hint: &str) -> ! { error!( "\ diff --git a/src/utils/parse.rs b/src/utils/parse.rs index 79b1fa47b..cfeb84fe9 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use num_traits::Zero; use thiserror::Error; @@ -17,12 +15,6 @@ pub enum ParseVectorError { TooShortNumber { position: usize }, #[error("Bad parsing at position {position}")] BadParsing { position: usize }, - #[error("Index out of bounds: the dim is {dims} but the index is {index}")] - OutOfBound { dims: usize, index: usize }, - #[error("The dimension should be {min} < dim < {max}, but it is actually {dims}")] - InvalidDimension { dims: usize, min: usize, max: usize }, - #[error("Indexes need to be unique, but there are more than one same index {index}")] - IndexConflict { index: usize }, } #[inline(always)] @@ -93,7 +85,7 @@ where Ok(vector) } -#[derive(PartialEq)] +#[derive(PartialEq, Debug)] enum ParseState { Start, LeftBracket, @@ -104,6 +96,26 @@ enum ParseState { Length, } +#[inline(always)] +pub fn svector_filter_nonzero( + indexes: &[u32], + values: &[T], +) -> (Vec, Vec) { + let non_zero_indexes: Vec = indexes + .iter() + .enumerate() + .filter(|(i, _)| values.get(*i).unwrap() != &T::zero()) + .map(|(_, x)| *x) + .collect(); + let non_zero_values: Vec = indexes + .iter() + .enumerate() + .filter(|(i, _)| values.get(*i).unwrap() != &T::zero()) + .map(|(i, _)| values.get(i).unwrap().clone()) + .collect(); + (non_zero_indexes, non_zero_values) +} + #[inline(always)] pub fn parse_pgvector_svector( input: &[u8], @@ -117,155 +129,87 @@ where return Err(ParseVectorError::EmptyString {}); } let mut token: ArrayVec = ArrayVec::new(); - let mut indexes = Vec::::new(); let mut values = Vec::::new(); - let mut all_indexes = Vec::::new(); - let mut index: u32 = u32::MAX; let mut state = ParseState::Start; - let mut position = 0; - loop { - if position >= input.len() { - break; - } - match state { - ParseState::Start => { - let c = input[position]; - match c { - b'{' => { - state = ParseState::LeftBracket; - } - b' ' => {} - _ => return Err(ParseVectorError::BadCharacter { position }), - } + for (position, char) in input.iter().enumerate() { + let c = *char; + match (&state, c) { + (_, b' ') => {} + (ParseState::Start, b'{') => { + state = ParseState::LeftBracket; } - ParseState::LeftBracket => { - let c = input[position]; - match c { - b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { - // Do not read it here, goto Index to read - position -= 1; - state = ParseState::Index; - } - b'}' => { - state = ParseState::Splitter; - } - b' ' => {} - _ => return Err(ParseVectorError::BadCharacter { position }), + ( + ParseState::LeftBracket | ParseState::Index | ParseState::Comma, + b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-', + ) => { + if token.is_empty() { + token.push(b'$'); + } + if token.try_push(c).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); } + state = ParseState::Index; } - ParseState::Index => { - let c = input[position]; - match c { - b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { - if token.is_empty() { - token.push(b'$'); - } - if token.try_push(c).is_err() { - return Err(ParseVectorError::TooLongNumber { position }); - } - } - b':' => { - if token.is_empty() { - return Err(ParseVectorError::TooShortNumber { position }); - } - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; - index = s - .parse::() - .map_err(|_| ParseVectorError::BadParsing { position })?; - token.clear(); - state = ParseState::Value; - } - b' ' => {} - _ => return Err(ParseVectorError::BadCharacter { position }), + (ParseState::LeftBracket | ParseState::Comma, b'}') => { + state = ParseState::Splitter; + } + (ParseState::Index, b':') => { + if token.is_empty() { + return Err(ParseVectorError::TooShortNumber { position }); } + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let index = s + .parse::() + .map_err(|_| ParseVectorError::BadParsing { position })?; + indexes.push(index); + token.clear(); + state = ParseState::Value; } - ParseState::Value => { - let c = input[position]; - match c { - b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { - if token.is_empty() { - token.push(b'$'); - } - if token.try_push(c).is_err() { - return Err(ParseVectorError::TooLongNumber { position }); - } - } - b',' => { - if token.is_empty() { - return Err(ParseVectorError::TooShortNumber { position }); - } - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; - let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; - if !num.is_zero() { - indexes.push(index); - values.push(num); - } - all_indexes.push(index); - token.clear(); - state = ParseState::Comma; - } - // Bracket ended with number - b'}' => { - if token.is_empty() { - return Err(ParseVectorError::TooShortNumber { position }); - } - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; - let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; - if !num.is_zero() { - indexes.push(index); - values.push(num); - } - all_indexes.push(index); - token.clear(); - state = ParseState::Splitter; - } - b' ' => {} - _ => return Err(ParseVectorError::BadCharacter { position }), + (ParseState::Value, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => { + if token.is_empty() { + token.push(b'$'); + } + if token.try_push(c).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); } } - ParseState::Comma => { - let c = input[position]; - match c { - b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { - // Do not read it here, goto Index to read - position -= 1; - state = ParseState::Index; - } - b'}' => { - // Bracket ended with comma - state = ParseState::Splitter; - } - b' ' => {} - _ => return Err(ParseVectorError::BadCharacter { position }), + (ParseState::Value, b',') => { + if token.is_empty() { + return Err(ParseVectorError::TooShortNumber { position }); } + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; + values.push(num); + token.clear(); + state = ParseState::Comma; } - ParseState::Splitter => { - let c = input[position]; - match c { - b'/' => { - state = ParseState::Length; - } - _ => return Err(ParseVectorError::BadCharacter { position }), + (ParseState::Value, b'}') => { + if token.is_empty() { + return Err(ParseVectorError::TooShortNumber { position }); } + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; + values.push(num); + token.clear(); + state = ParseState::Splitter; + } + (ParseState::Splitter, b'/') => { + state = ParseState::Length; } - ParseState::Length => { - let c = input[position]; - match c { - b'0'..=b'9' => { - if token.is_empty() { - token.push(b'$'); - } - if token.try_push(c).is_err() { - return Err(ParseVectorError::TooLongNumber { position }); - } - } - _ => return Err(ParseVectorError::BadCharacter { position }), + (ParseState::Length, b'0'..=b'9') => { + if token.is_empty() { + token.push(b'$'); + } + if token.try_push(c).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); } } + (_, _) => { + return Err(ParseVectorError::BadCharacter { position }); + } } - position += 1; } if state != ParseState::Length { return Err(ParseVectorError::BadParsing { @@ -273,42 +217,16 @@ where }); } if token.is_empty() { - return Err(ParseVectorError::TooShortNumber { position }); + return Err(ParseVectorError::TooShortNumber { + position: input.len(), + }); } let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; let dims = s .parse::() - .map_err(|_| ParseVectorError::BadParsing { position })?; - - // Check dimension out of bound - if dims == 0 || dims >= 1048576 { - return Err(ParseVectorError::InvalidDimension { - dims, - min: 0, - max: 1048576, - }); - } - // Check index out of bound - for index in all_indexes.clone() { - if index as usize >= dims { - return Err(ParseVectorError::OutOfBound { - dims, - index: index as usize, - }); - } - } - // Check index conflicts - let mut result: HashMap = HashMap::new(); - for index in all_indexes { - if let Some(value) = result.get(&index) { - if *value == 1 { - return Err(ParseVectorError::IndexConflict { - index: index as usize, - }); - } - } - *result.entry(index).or_insert(0) += 1; - } + .map_err(|_| ParseVectorError::BadParsing { + position: input.len(), + })?; let mut indices = (0..indexes.len()).collect::>(); indices.sort_by_key(|&i| &indexes[i]); @@ -317,20 +235,19 @@ where .map(|i| values.get(*i).unwrap().clone()) .collect(); indexes.sort(); + Ok((indexes, sorted_values, dims)) } #[cfg(test)] mod tests { - use std::collections::HashMap; - use base::scalar::F32; use super::*; #[test] fn test_svector_parse_accept() { - let exprs: HashMap<&str, (Vec, Vec, usize)> = HashMap::from([ + let exprs: Vec<(&str, (Vec, Vec, usize))> = vec![ ("{}/1", (vec![], vec![], 1)), ("{0:1}/1", (vec![0], vec![F32(1.0)], 1)), ( @@ -342,59 +259,33 @@ mod tests { "{0:+3, 2:-4.1}/3", (vec![0, 2], vec![F32(3.0), F32(-4.1)], 3), ), - ("{0:0, 1:0, 2:0}/3", (vec![], vec![], 3)), + ( + "{0:0, 1:0, 2:0}/3", + (vec![0, 1, 2], vec![F32(0.0), F32(0.0), F32(0.0)], 3), + ), ( "{3:3, 2:2, 1:1, 0:0}/4", - (vec![1, 2, 3], vec![F32(1.0), F32(2.0), F32(3.0)], 4), + ( + vec![0, 1, 2, 3], + vec![F32(0.0), F32(1.0), F32(2.0), F32(3.0)], + 4, + ), ), - ]); - for (e, ans) in exprs { + ]; + for (e, parsed) in exprs { let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); assert!(ret.is_ok(), "at expr {:?}: {:?}", e, ret); - assert_eq!(ret.unwrap(), ans, "at expr {:?}", e); + assert_eq!(ret.unwrap(), parsed, "parsed at expr {:?}", e); } } #[test] fn test_svector_parse_reject() { - let exprs: HashMap<&str, ParseVectorError> = HashMap::from([ + let exprs: Vec<(&str, ParseVectorError)> = vec![ ("{", ParseVectorError::BadParsing { position: 1 }), ("}", ParseVectorError::BadCharacter { position: 0 }), ("{:", ParseVectorError::BadCharacter { position: 1 }), (":}", ParseVectorError::BadCharacter { position: 0 }), - ( - "{}/0", - ParseVectorError::InvalidDimension { - dims: 0, - min: 0, - max: 1048576, - }, - ), - ( - "{}/1919810", - ParseVectorError::InvalidDimension { - dims: 1919810, - min: 0, - max: 1048576, - }, - ), - ("{0:1, 0:2}/1", ParseVectorError::IndexConflict { index: 0 }), - ( - "{0:1, 1:1.5}/1", - ParseVectorError::OutOfBound { dims: 1, index: 1 }, - ), - ( - "{0:0, 1:0, 2:0}/2", - ParseVectorError::OutOfBound { dims: 2, index: 2 }, - ), - ( - "{2:0, 1:0}/2", - ParseVectorError::OutOfBound { dims: 2, index: 2 }, - ), - ( - "{2:0, 1:0, }/2", - ParseVectorError::OutOfBound { dims: 2, index: 2 }, - ), ( "{0:1, 1:2, 2:3}", ParseVectorError::BadParsing { position: 15 }, @@ -426,11 +317,53 @@ mod tests { "{1,2,3,4}/5", ParseVectorError::BadCharacter { position: 2 }, ), - ]); + ]; for (e, err) in exprs { let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); assert!(ret.is_err(), "at expr {:?}: {:?}", e, ret); - assert_eq!(ret.unwrap_err(), err, "at expr {:?}", e); + assert_eq!(ret.unwrap_err(), err, "parsed at expr {:?}", e); + } + } + + #[test] + fn test_svector_parse_filter() { + let exprs: Vec<(&str, (Vec, Vec, usize), (Vec, Vec))> = vec![ + ("{}/0", (vec![], vec![], 0), (vec![], vec![])), + ("{}/1919810", (vec![], vec![], 1919810), (vec![], vec![])), + ( + "{0:1, 0:2}/1", + (vec![0, 0], vec![F32(1.0), F32(2.0)], 1), + (vec![0, 0], vec![F32(1.0), F32(2.0)]), + ), + ( + "{0:1, 1:1.5}/1", + (vec![0, 1], vec![F32(1.0), F32(1.5)], 1), + (vec![0, 1], vec![F32(1.0), F32(1.5)]), + ), + ( + "{0:0, 1:0, 2:0}/2", + (vec![0, 1, 2], vec![F32(0.0), F32(0.0), F32(0.0)], 2), + (vec![], vec![]), + ), + ( + "{2:0, 1:0}/2", + (vec![1, 2], vec![F32(0.0), F32(0.0)], 2), + (vec![], vec![]), + ), + ( + "{2:0, 1:0, }/2", + (vec![1, 2], vec![F32(0.0), F32(0.0)], 2), + (vec![], vec![]), + ), + ]; + for (e, parsed, filtered) in exprs { + let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); + assert!(ret.is_ok(), "at expr {:?}: {:?}", e, ret); + assert_eq!(ret.unwrap(), parsed, "parsed at expr {:?}", e); + + let (indexes, values, _) = parsed; + let nonzero = svector_filter_nonzero(&indexes, &values); + assert_eq!(nonzero, filtered, "filtered at expr {:?}", e); } } } From 94312efc719686f9e0e25b956eb1c22c2c32ffc6 Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Tue, 18 Jun 2024 17:06:13 +0800 Subject: [PATCH 09/10] fix: by comments Signed-off-by: cutecutecat --- src/datatype/text_svecf32.rs | 11 +-- src/error.rs | 8 +- src/utils/parse.rs | 160 +++++++++++++++++------------------ 3 files changed, 89 insertions(+), 90 deletions(-) diff --git a/src/datatype/text_svecf32.rs b/src/datatype/text_svecf32.rs index ed64c3716..59c3838bd 100644 --- a/src/datatype/text_svecf32.rs +++ b/src/datatype/text_svecf32.rs @@ -9,20 +9,21 @@ use std::fmt::Write; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output { - use crate::utils::parse::{parse_pgvector_svector, svector_filter_nonzero}; + use crate::utils::parse::{parse_pgvector_svector, svector_filter_nonzero, svector_sorted}; let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::().ok()); match v { Err(e) => { bad_literal(&e.to_string()); } Ok((indexes, values, dims)) => { + let (mut sorted_indexes, mut sorted_values) = svector_sorted(&indexes, &values); check_value_dims_1048575(dims); - check_index_in_bound(&indexes, dims); - let (non_zero_indexes, non_zero_values) = svector_filter_nonzero(&indexes, &values); + check_index_in_bound(&sorted_indexes, dims); + svector_filter_nonzero(&mut sorted_indexes, &mut sorted_values); SVecf32Output::new(SVecf32Borrowed::new( dims as u32, - &non_zero_indexes, - &non_zero_values, + &sorted_indexes, + &sorted_values, )) } } diff --git a/src/error.rs b/src/error.rs index 99a8a34c0..6a7226858 100644 --- a/src/error.rs +++ b/src/error.rs @@ -70,15 +70,15 @@ ADVICE: Check if dimensions of the vector are among 1 and 1_048_575." } pub fn check_index_in_bound(indexes: &[u32], dims: usize) -> NonZeroU32 { - let mut last: u32 = 0; - for (i, index) in indexes.iter().enumerate() { - if i > 0 && last == *index { + let mut last: Option = None; + for index in indexes { + if last == Some(*index) { error!("Indexes need to be unique, but there are more than one same index {index}") } if *index >= dims as u32 { error!("Index out of bounds: the dim is {dims} but the index is {index}"); } - last = *index; + last = Some(*index); } NonZeroU32::new(dims as u32).unwrap() } diff --git a/src/utils/parse.rs b/src/utils/parse.rs index cfeb84fe9..500de3232 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -85,35 +85,54 @@ where Ok(vector) } -#[derive(PartialEq, Debug)] +#[derive(PartialEq, Debug, Clone)] enum ParseState { Start, LeftBracket, Index, + Colon, Value, - Splitter, Comma, - Length, + RightBracket, + Splitter, + Dims, } #[inline(always)] -pub fn svector_filter_nonzero( +pub fn svector_sorted( indexes: &[u32], values: &[T], ) -> (Vec, Vec) { - let non_zero_indexes: Vec = indexes - .iter() - .enumerate() - .filter(|(i, _)| values.get(*i).unwrap() != &T::zero()) - .map(|(_, x)| *x) - .collect(); - let non_zero_values: Vec = indexes - .iter() - .enumerate() - .filter(|(i, _)| values.get(*i).unwrap() != &T::zero()) - .map(|(i, _)| values.get(i).unwrap().clone()) - .collect(); - (non_zero_indexes, non_zero_values) + let mut indices = (0..indexes.len()).collect::>(); + indices.sort_by_key(|&i| &indexes[i]); + + let mut sorted_indexes: Vec = Vec::with_capacity(indexes.len()); + let mut sorted_values: Vec = Vec::with_capacity(indexes.len()); + for i in indices { + sorted_indexes.push(*indexes.get(i).unwrap()); + sorted_values.push(values.get(i).unwrap().clone()); + } + (sorted_indexes, sorted_values) +} + +#[inline(always)] +pub fn svector_filter_nonzero( + indexes: &mut Vec, + values: &mut Vec, +) { + // Index must be sorted! + let mut i = 0; + let mut j = 0; + while j < values.len() { + if !values[j].is_zero() { + indexes[i] = indexes[j]; + values[i] = values[j].clone(); + i += 1; + } + j += 1; + } + indexes.truncate(i); + values.truncate(i); } #[inline(always)] @@ -133,110 +152,82 @@ where let mut values = Vec::::new(); let mut state = ParseState::Start; - for (position, char) in input.iter().enumerate() { - let c = *char; - match (&state, c) { - (_, b' ') => {} - (ParseState::Start, b'{') => { - state = ParseState::LeftBracket; - } + for (position, c) in input.iter().copied().enumerate() { + state = match (&state, c) { + (_, b' ') => state, + (ParseState::Start, b'{') => ParseState::LeftBracket, ( ParseState::LeftBracket | ParseState::Index | ParseState::Comma, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-', ) => { - if token.is_empty() { - token.push(b'$'); - } if token.try_push(c).is_err() { return Err(ParseVectorError::TooLongNumber { position }); } - state = ParseState::Index; + ParseState::Index } - (ParseState::LeftBracket | ParseState::Comma, b'}') => { - state = ParseState::Splitter; + (ParseState::Colon, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => { + if token.try_push(c).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); + } + ParseState::Value } + (ParseState::LeftBracket | ParseState::Comma, b'}') => ParseState::RightBracket, (ParseState::Index, b':') => { - if token.is_empty() { - return Err(ParseVectorError::TooShortNumber { position }); - } - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; let index = s .parse::() .map_err(|_| ParseVectorError::BadParsing { position })?; indexes.push(index); token.clear(); - state = ParseState::Value; + ParseState::Colon } (ParseState::Value, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => { - if token.is_empty() { - token.push(b'$'); - } if token.try_push(c).is_err() { return Err(ParseVectorError::TooLongNumber { position }); } + ParseState::Value } (ParseState::Value, b',') => { - if token.is_empty() { - return Err(ParseVectorError::TooShortNumber { position }); - } - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; values.push(num); token.clear(); - state = ParseState::Comma; + ParseState::Comma } (ParseState::Value, b'}') => { if token.is_empty() { return Err(ParseVectorError::TooShortNumber { position }); } - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; values.push(num); token.clear(); - state = ParseState::Splitter; - } - (ParseState::Splitter, b'/') => { - state = ParseState::Length; + ParseState::RightBracket } - (ParseState::Length, b'0'..=b'9') => { - if token.is_empty() { - token.push(b'$'); - } + (ParseState::RightBracket, b'/') => ParseState::Splitter, + (ParseState::Dims | ParseState::Splitter, b'0'..=b'9') => { if token.try_push(c).is_err() { return Err(ParseVectorError::TooLongNumber { position }); } + ParseState::Dims } (_, _) => { return Err(ParseVectorError::BadCharacter { position }); } } } - if state != ParseState::Length { + if state != ParseState::Dims { return Err(ParseVectorError::BadParsing { position: input.len(), }); } - if token.is_empty() { - return Err(ParseVectorError::TooShortNumber { - position: input.len(), - }); - } - let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; let dims = s .parse::() .map_err(|_| ParseVectorError::BadParsing { position: input.len(), })?; - - let mut indices = (0..indexes.len()).collect::>(); - indices.sort_by_key(|&i| &indexes[i]); - let sorted_values: Vec = indices - .iter() - .map(|i| values.get(*i).unwrap().clone()) - .collect(); - indexes.sort(); - - Ok((indexes, sorted_values, dims)) + Ok((indexes, values, dims)) } #[cfg(test)] @@ -266,8 +257,8 @@ mod tests { ( "{3:3, 2:2, 1:1, 0:0}/4", ( - vec![0, 1, 2, 3], - vec![F32(0.0), F32(1.0), F32(2.0), F32(3.0)], + vec![3, 2, 1, 0], + vec![F32(3.0), F32(2.0), F32(1.0), F32(0.0)], 4, ), ), @@ -294,16 +285,13 @@ mod tests { "{0:1, 1:2, 2:3", ParseVectorError::BadParsing { position: 14 }, ), - ( - "{0:1, 1:2}/", - ParseVectorError::TooShortNumber { position: 11 }, - ), + ("{0:1, 1:2}/", ParseVectorError::BadParsing { position: 11 }), ("{0}/5", ParseVectorError::BadCharacter { position: 2 }), - ("{0:}/5", ParseVectorError::TooShortNumber { position: 3 }), + ("{0:}/5", ParseVectorError::BadCharacter { position: 3 }), ("{:0}/5", ParseVectorError::BadCharacter { position: 1 }), ( "{0:, 1:2}/5", - ParseVectorError::TooShortNumber { position: 3 }, + ParseVectorError::BadCharacter { position: 3 }, ), ("{0:1, 1}/5", ParseVectorError::BadCharacter { position: 7 }), ("/2", ParseVectorError::BadCharacter { position: 0 }), @@ -347,14 +335,23 @@ mod tests { ), ( "{2:0, 1:0}/2", - (vec![1, 2], vec![F32(0.0), F32(0.0)], 2), + (vec![2, 1], vec![F32(0.0), F32(0.0)], 2), (vec![], vec![]), ), ( "{2:0, 1:0, }/2", - (vec![1, 2], vec![F32(0.0), F32(0.0)], 2), + (vec![2, 1], vec![F32(0.0), F32(0.0)], 2), (vec![], vec![]), ), + ( + "{3:2, 2:1, 1:0, 0:-1}/4", + ( + vec![3, 2, 1, 0], + vec![F32(2.0), F32(1.0), F32(0.0), F32(-1.0)], + 4, + ), + (vec![0, 2, 3], vec![F32(-1.0), F32(1.0), F32(2.0)]), + ), ]; for (e, parsed, filtered) in exprs { let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); @@ -362,8 +359,9 @@ mod tests { assert_eq!(ret.unwrap(), parsed, "parsed at expr {:?}", e); let (indexes, values, _) = parsed; - let nonzero = svector_filter_nonzero(&indexes, &values); - assert_eq!(nonzero, filtered, "filtered at expr {:?}", e); + let (mut indexes, mut values) = svector_sorted(&indexes, &values); + svector_filter_nonzero(&mut indexes, &mut values); + assert_eq!((indexes, values), filtered, "filtered at expr {:?}", e); } } } From 294f0c88df1969ccc7f907ef45c6eb6b3949b3c4 Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Tue, 18 Jun 2024 18:41:32 +0800 Subject: [PATCH 10/10] fix: remove funcs Signed-off-by: cutecutecat --- src/datatype/text_svecf32.rs | 67 +++++++++++++++++++++++---- src/error.rs | 14 ------ src/utils/parse.rs | 89 ------------------------------------ 3 files changed, 57 insertions(+), 113 deletions(-) diff --git a/src/datatype/text_svecf32.rs b/src/datatype/text_svecf32.rs index 59c3838bd..ea94c9e54 100644 --- a/src/datatype/text_svecf32.rs +++ b/src/datatype/text_svecf32.rs @@ -1,3 +1,6 @@ +use num_traits::Zero; +use pgrx::error; + use super::memory_svecf32::SVecf32Output; use crate::datatype::memory_svecf32::SVecf32Input; use crate::error::*; @@ -9,22 +12,66 @@ use std::fmt::Write; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output { - use crate::utils::parse::{parse_pgvector_svector, svector_filter_nonzero, svector_sorted}; + use crate::utils::parse::parse_pgvector_svector; let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::().ok()); match v { Err(e) => { bad_literal(&e.to_string()); } - Ok((indexes, values, dims)) => { - let (mut sorted_indexes, mut sorted_values) = svector_sorted(&indexes, &values); + Ok((mut indexes, mut values, dims)) => { check_value_dims_1048575(dims); - check_index_in_bound(&sorted_indexes, dims); - svector_filter_nonzero(&mut sorted_indexes, &mut sorted_values); - SVecf32Output::new(SVecf32Borrowed::new( - dims as u32, - &sorted_indexes, - &sorted_values, - )) + // is_sorted + if !indexes.windows(2).all(|i| i[0] <= i[1]) { + assert_eq!(indexes.len(), values.len()); + let n = indexes.len(); + let mut permutation = (0..n).collect::>(); + permutation.sort_unstable_by_key(|&i| &indexes[i]); + for i in 0..n { + if i == permutation[i] || usize::MAX == permutation[i] { + continue; + } + let index = indexes[i]; + let value = values[i]; + let mut j = i; + while i != permutation[j] { + let next = permutation[j]; + indexes[j] = indexes[permutation[j]]; + values[j] = values[permutation[j]]; + permutation[j] = usize::MAX; + j = next; + } + indexes[j] = index; + values[j] = value; + permutation[j] = usize::MAX; + } + } + let mut last: Option = None; + for index in indexes.clone() { + if last == Some(index) { + error!( + "Indexes need to be unique, but there are more than one same index {index}" + ) + } + if last >= Some(dims as u32) { + error!("Index out of bounds: the dim is {dims} but the index is {index}"); + } + last = Some(index); + { + let mut i = 0; + let mut j = 0; + while j < values.len() { + if !values[j].is_zero() { + indexes[i] = indexes[j]; + values[i] = values[j]; + i += 1; + } + j += 1; + } + indexes.truncate(i); + values.truncate(i); + } + } + SVecf32Output::new(SVecf32Borrowed::new(dims as u32, &indexes, &values)) } } } diff --git a/src/error.rs b/src/error.rs index 6a7226858..fd368de4c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -69,20 +69,6 @@ ADVICE: Check if dimensions of the vector are among 1 and 1_048_575." NonZeroU32::new(dims as u32).unwrap() } -pub fn check_index_in_bound(indexes: &[u32], dims: usize) -> NonZeroU32 { - let mut last: Option = None; - for index in indexes { - if last == Some(*index) { - error!("Indexes need to be unique, but there are more than one same index {index}") - } - if *index >= dims as u32 { - error!("Index out of bounds: the dim is {dims} but the index is {index}"); - } - last = Some(*index); - } - NonZeroU32::new(dims as u32).unwrap() -} - pub fn bad_literal(hint: &str) -> ! { error!( "\ diff --git a/src/utils/parse.rs b/src/utils/parse.rs index 500de3232..3eac1afbf 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -98,43 +98,6 @@ enum ParseState { Dims, } -#[inline(always)] -pub fn svector_sorted( - indexes: &[u32], - values: &[T], -) -> (Vec, Vec) { - let mut indices = (0..indexes.len()).collect::>(); - indices.sort_by_key(|&i| &indexes[i]); - - let mut sorted_indexes: Vec = Vec::with_capacity(indexes.len()); - let mut sorted_values: Vec = Vec::with_capacity(indexes.len()); - for i in indices { - sorted_indexes.push(*indexes.get(i).unwrap()); - sorted_values.push(values.get(i).unwrap().clone()); - } - (sorted_indexes, sorted_values) -} - -#[inline(always)] -pub fn svector_filter_nonzero( - indexes: &mut Vec, - values: &mut Vec, -) { - // Index must be sorted! - let mut i = 0; - let mut j = 0; - while j < values.len() { - if !values[j].is_zero() { - indexes[i] = indexes[j]; - values[i] = values[j].clone(); - i += 1; - } - j += 1; - } - indexes.truncate(i); - values.truncate(i); -} - #[inline(always)] pub fn parse_pgvector_svector( input: &[u8], @@ -312,56 +275,4 @@ mod tests { assert_eq!(ret.unwrap_err(), err, "parsed at expr {:?}", e); } } - - #[test] - fn test_svector_parse_filter() { - let exprs: Vec<(&str, (Vec, Vec, usize), (Vec, Vec))> = vec![ - ("{}/0", (vec![], vec![], 0), (vec![], vec![])), - ("{}/1919810", (vec![], vec![], 1919810), (vec![], vec![])), - ( - "{0:1, 0:2}/1", - (vec![0, 0], vec![F32(1.0), F32(2.0)], 1), - (vec![0, 0], vec![F32(1.0), F32(2.0)]), - ), - ( - "{0:1, 1:1.5}/1", - (vec![0, 1], vec![F32(1.0), F32(1.5)], 1), - (vec![0, 1], vec![F32(1.0), F32(1.5)]), - ), - ( - "{0:0, 1:0, 2:0}/2", - (vec![0, 1, 2], vec![F32(0.0), F32(0.0), F32(0.0)], 2), - (vec![], vec![]), - ), - ( - "{2:0, 1:0}/2", - (vec![2, 1], vec![F32(0.0), F32(0.0)], 2), - (vec![], vec![]), - ), - ( - "{2:0, 1:0, }/2", - (vec![2, 1], vec![F32(0.0), F32(0.0)], 2), - (vec![], vec![]), - ), - ( - "{3:2, 2:1, 1:0, 0:-1}/4", - ( - vec![3, 2, 1, 0], - vec![F32(2.0), F32(1.0), F32(0.0), F32(-1.0)], - 4, - ), - (vec![0, 2, 3], vec![F32(-1.0), F32(1.0), F32(2.0)]), - ), - ]; - for (e, parsed, filtered) in exprs { - let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::().ok()); - assert!(ret.is_ok(), "at expr {:?}: {:?}", e, ret); - assert_eq!(ret.unwrap(), parsed, "parsed at expr {:?}", e); - - let (indexes, values, _) = parsed; - let (mut indexes, mut values) = svector_sorted(&indexes, &values); - svector_filter_nonzero(&mut indexes, &mut values); - assert_eq!((indexes, values), filtered, "filtered at expr {:?}", e); - } - } }